In [None]:
import os
import time
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.optim import Optimizer
import matplotlib.pyplot as plt

In [None]:
# 1. SAM Optimizer Implementation
class SAM(Optimizer):
    """Sharpness-Aware Minimization (SAM) wrapper around a base optimizer.

    A SAM update uses two forward/backward passes per batch:

    1. ``first_step`` moves every parameter along its gradient direction by a
       step whose global L2 norm is ``rho``.
    2. ``second_step`` undoes that move and lets the base optimizer apply the
       gradients that were computed at the perturbed point.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        base_optimizer: Optimizer *class* (e.g. ``torch.optim.SGD``); it is
            instantiated on this optimizer's parameter groups.
        rho: Non-negative neighbourhood radius of the perturbation.
        **kwargs: Forwarded to ``base_optimizer`` (``lr``, ``momentum``, ...).

    Raises:
        ValueError: If ``rho`` is negative.
    """

    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        if rho < 0.0:
            raise ValueError(f"Invalid rho: {rho}")
        defaults = dict(rho=rho, **kwargs)
        super().__init__(params, defaults)
        # The base optimizer is built on the very same param-group dicts, so
        # both optimizers always see the same parameters and hyper-parameters.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        # Guards the division in first_step against a zero gradient norm.
        self.eps = 1e-12

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Perturb the weights along the gradient and remember the offset.

        Args:
            zero_grad: When True, clear the gradients after perturbing.
        """
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            # Use each group's own rho rather than always the first group's.
            scale = group['rho'] / (grad_norm + self.eps)
            for p in group['params']:
                if p.grad is None:
                    continue
                # .to(p) keeps the scalar on the parameter's device/dtype.
                e_w = p.grad * scale.to(p)
                p.add_(e_w)
                self.state[p]['e_w'] = e_w
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Restore the original weights, then run the base optimizer step.

        Restoration is keyed on the stored perturbation instead of on
        ``p.grad``: a parameter that was perturbed in ``first_step`` is always
        moved back, even if it has no gradient now, and a parameter that was
        never perturbed is left alone.

        Args:
            zero_grad: When True, clear the gradients after the update.
        """
        for group in self.param_groups:
            for p in group['params']:
                # .get() avoids creating empty state entries as a side effect.
                param_state = self.state.get(p)
                if not param_state:
                    continue
                e_w = param_state.pop('e_w', None)
                if e_w is not None:
                    p.sub_(e_w)
        self.base_optimizer.step()
        if zero_grad:
            self.zero_grad()

    def step(self, closure=None):
        """Run a full SAM update using ``closure`` for both passes.

        Args:
            closure: Callable that returns the loss *without* calling
                ``backward()``; it is evaluated twice.
        """
        assert closure is not None, "SAM requires closure"
        closure().backward()
        self.first_step(zero_grad=True)
        closure().backward()
        self.second_step(zero_grad=True)

    def _grad_norm(self):
        """Return the global L2 norm of all available gradients.

        Returns a zero scalar when no parameter has a gradient, so callers do
        not crash on an empty ``torch.stack``.
        """
        device = self.param_groups[0]['params'][0].device
        norms = [p.grad.norm(p=2).to(device)
                 for group in self.param_groups for p in group['params'] if p.grad is not None]
        if not norms:
            return torch.zeros((), device=device)
        return torch.norm(torch.stack(norms), p=2)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque
import heapq

# --- 1. Prioritized Replay Memory ---
class PrioritizedReplayMemory:
    """Fixed-capacity ring buffer with proportional prioritized sampling.

    Each stored priority is ``(|error| + 1e-5) ** alpha``; the exponent is
    applied exactly once, when the priority is written.

    Args:
        capacity: Maximum number of experiences kept; the oldest slot is
            overwritten once the buffer is full.
        alpha: Prioritization exponent applied to the absolute error.
    """

    def __init__(self, capacity, alpha=0.6):
        self.capacity = capacity
        self.alpha = alpha
        self.buffer = []
        self.priorities = []
        self.pos = 0  # next slot to overwrite once the buffer is full

    def add(self, experience, error):
        """Store ``experience`` with a priority derived from ``error``."""
        priority = (abs(error) + 1e-5) ** self.alpha
        if len(self.buffer) < self.capacity:
            self.buffer.append(experience)
            self.priorities.append(priority)
        else:
            self.buffer[self.pos] = experience
            self.priorities[self.pos] = priority
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size, beta=0.4):
        """Draw ``batch_size`` experiences with probability ~ priority.

        Args:
            batch_size: Number of experiences to draw (with replacement).
            beta: Exponent of the importance-sampling correction.

        Returns:
            ``(samples, indices, weights)`` where ``weights`` are the
            importance-sampling weights normalized so the largest is 1.

        Raises:
            ValueError: If the memory is empty.
        """
        if not self.buffer:
            raise ValueError("Cannot sample from an empty replay memory")

        # Stored priorities already include the alpha exponent (see add and
        # update_priority), so they are only normalized here.
        priorities = np.asarray(self.priorities, dtype=np.float64)
        probs = priorities / priorities.sum()

        indices = np.random.choice(len(self.buffer), batch_size, p=probs)
        samples = [self.buffer[idx] for idx in indices]

        weights = (len(self.buffer) * probs[indices]) ** (-beta)
        weights /= weights.max()

        return samples, indices, weights

    def update_priority(self, idx, error):
        """Overwrite the priority of slot ``idx`` using a new ``error``."""
        self.priorities[idx] = (abs(error) + 1e-5) ** self.alpha

# --- 2. Modified DQN Agent ---
class DQNAgent:
    """Epsilon-greedy DQN agent with separate policy and target networks.

    Args:
        state_dim: Size of the state vector fed to the Q-network.
        action_dim: Number of discrete actions.
        gamma: Discount factor.
        epsilon_start: Initial exploration probability.
        epsilon_min: Lower bound for the exploration probability.
        epsilon_decay: Multiplicative decay factor for the exploration
            probability.
    """

    def __init__(self, state_dim=512, action_dim=10,
                 gamma=0.99, epsilon_start=1.0, epsilon_min=0.01, epsilon_decay=0.995):
        self.gamma = gamma
        self.epsilon = epsilon_start
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.memory = PrioritizedReplayMemory(10000)
        self.learn_step = 0
        self.action_dim = action_dim

        # Q-networks: each network gets its OWN layer objects. Reusing one
        # list of layers for both would make the target network share its
        # weights with the policy network.
        self.policy_net = self._build_q_network(state_dim, action_dim)
        self.target_net = self._build_q_network(state_dim, action_dim)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=1e-3)
        self.model = self.policy_net  # Alias

    @staticmethod
    def _build_q_network(state_dim, action_dim):
        """Create a fresh MLP mapping a state to one Q-value per action."""
        return nn.Sequential(
            nn.Linear(state_dim, 64), nn.ReLU(),
            nn.Linear(64, 64), nn.ReLU(),
            nn.Linear(64, action_dim),
        )

    def act(self, state):
        """Pick an action epsilon-greedily.

        Args:
            state: Sequence/array of ``state_dim`` floats.

        Returns:
            Integer action index in ``[0, action_dim)``.
        """
        if random.random() < self.epsilon:
            return random.randrange(self.action_dim)
        with torch.no_grad():
            q = self.policy_net(torch.FloatTensor(state).unsqueeze(0))
        return int(q.argmax(1).item())

    # Keep the remember and learn methods from your own code

# --- 3. Integration with ResNet and SAM ---


class FeatureExtractor(nn.Module):
    """ResNet-18 backbone with its final fully connected layer removed.

    Args:
        pretrained: When True the torchvision default ResNet-18 weights are
            requested; otherwise ``weights=None`` is passed to the builder.

    Attributes:
        features: All children of the backbone except the last one.
        feature_dim: Input width of the discarded FC layer.
    """

    def __init__(self, pretrained: bool = False):
        super().__init__()

        # Pick the checkpoint first, then build the backbone with it.
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        backbone = models.resnet18(weights=weights)

        # Width of the vector this module outputs (the original comment
        # notes this is 512 for ResNet-18).
        self.feature_dim = backbone.fc.in_features

        # Drop the last child (the FC head) and keep the rest in order.
        trunk = list(backbone.children())
        self.features = nn.Sequential(*trunk[:-1])

    def forward(self, x):
        """Return one flattened feature vector per sample in ``x``."""
        feats = self.features(x)
        batch = feats.size(0)
        # Collapse everything after the batch axis into one dimension.
        return feats.view(batch, -1)


def initialize_system():
    """Build the feature extractor, the DQN agent and the SAM optimizer.

    Returns:
        Tuple ``(feature_extractor, dqn_agent, sam_optimizer, device)``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # FeatureExtractor builds its own ResNet-18; it takes a boolean flag, not
    # a model instance. pretrained=True keeps the original intent of loading
    # the default torchvision weights.
    feature_extractor = FeatureExtractor(pretrained=True).to(device)

    # DQNAgent is a plain Python object (not an nn.Module), so it has no
    # .to(device). The initial epsilon is left at its default of 1.0 because
    # the two DQNAgent definitions in this notebook name that argument
    # differently (epsilon_start vs. epsilon).
    dqn_agent = DQNAgent(
        state_dim=feature_extractor.feature_dim,
        action_dim=10,  # number of CIFAR-10 classes
        gamma=0.99,
        epsilon_min=0.01,
        epsilon_decay=0.995
    )

    # SAM optimizer for the feature extractor, wrapping SGD
    base_optimizer = optim.SGD
    sam_optimizer = SAM(
        feature_extractor.parameters(),
        base_optimizer,
        lr=0.1,
        momentum=0.9
    )

    return feature_extractor, dqn_agent, sam_optimizer, device

# --- 4. Training loop with DQN integration ---
def train_epoch(feature_extractor, dqn_agent, sam_optimizer, train_loader, criterion, device):
    """Run one epoch of SAM training and feed one transition per batch to the DQN.

    Args:
        feature_extractor: Module whose output is passed straight to
            ``criterion``.
        dqn_agent: Object exposing ``act``, ``remember`` and ``learn``.
        sam_optimizer: Optimizer exposing ``first_step`` / ``second_step``.
        train_loader: Sized iterable of ``(inputs, targets)`` batches.
        criterion: Loss taking ``(outputs, targets)``.
        device: Device the batches are moved to.
    """
    feature_extractor.train()
    last_batch_idx = len(train_loader) - 1

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        # SAM first step: gradient at the current weights, then perturb.
        features = feature_extractor(inputs)
        # NOTE(review): the raw features are used as logits here; the original
        # comment assumes a suitable classifier is in place - confirm.
        loss = criterion(features, targets)
        loss.backward()
        sam_optimizer.first_step(zero_grad=True)

        # SAM second step: gradient at the perturbed weights, then update.
        criterion(feature_extractor(inputs), targets).backward()
        sam_optimizer.second_step(zero_grad=True)

        # DQN experience collection (first sample of the batch only)
        state = features[0].detach().cpu().numpy()
        action = dqn_agent.act(state)
        # The post-update features are only read as an observation, so no
        # autograd graph needs to be built for this forward pass.
        with torch.no_grad():
            next_state = feature_extractor(inputs)[0].cpu().numpy()
        reward = 1.0 if torch.argmax(features[0]) == targets[0] else -1.0
        done = batch_idx == last_batch_idx

        dqn_agent.remember(state, action, reward, next_state, done)
        dqn_agent.learn()

In [None]:
from torch.utils.data import DataLoader, random_split  # <-- thêm random_split vào đây



In [None]:
# 5. Model Setup
# Creates the global device, model, loss, both optimizers and the DQN agent.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# ResNet-18 built with weights=None (no checkpoint is loaded)
model = models.resnet18(weights=None)
# Replace the final FC layer with a 10-output head
model.fc = nn.Linear(model.fc.in_features, 10)
model.to(device)
criterion = nn.CrossEntropyLoss()
# Two optimizers over the same model parameters: plain SGD and SAM wrapping SGD
opt_sgd = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
opt_sam = SAM(model.parameters(), optim.SGD, lr=0.1, momentum=0.9, rho=0.05)
# NOTE(review): presumably the two actions are "use SGD" / "use SAM"; what the
# 2-dimensional state holds is not shown in this cell - confirm.
agent = DQNAgent(state_dim=2, action_dim=2)

In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

# Transforms for the training and validation sets
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                         std=(0.2023, 0.1994, 0.2010))
])
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                         std=(0.2023, 0.1994, 0.2010))
])

# Build the dataset from the train directory (CIFAR-10 training set)
train_dir = "/kaggle/input/cifar10/cifar10/train"
full_dataset = datasets.ImageFolder(root=train_dir)

# Split into 90% train and 10% val. A seeded generator makes the split
# identical on every run, so validation numbers are comparable across runs.
SPLIT_SEED = 42
dataset_size = len(full_dataset)
train_size = int(0.9 * dataset_size)
val_size = dataset_size - train_size
train_subset, val_subset = random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED)
)
train_indices = train_subset.indices
val_indices = val_subset.indices

# Re-wrap the same indices with the split-specific transforms: the split was
# drawn on a transform-less ImageFolder, so each subset gets its own pipeline.
train_dataset = torch.utils.data.Subset(
    datasets.ImageFolder(root=train_dir, transform=train_transform),
    train_indices
)
val_dataset = torch.utils.data.Subset(
    datasets.ImageFolder(root=train_dir, transform=val_transform),
    val_indices
)

# DataLoaders for the train and val sets
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True,
                          num_workers=0, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False,
                        num_workers=0, pin_memory=True)

# Print the number of samples in each set
print(f"Số lượng ảnh train: {len(train_dataset)}")
print(f"Số lượng ảnh val: {len(val_dataset)}")


In [None]:
# 6. Training & Validation Functions
def train_one_epoch(model, optimizer, loader, device):
    """Train ``model`` for one pass over ``loader``.

    Uses the notebook-level ``criterion`` as the loss. When ``optimizer`` is a
    ``SAM`` instance the two-pass SAM update is performed explicitly, which
    needs exactly two forward passes per batch. Loss and accuracy are both
    measured on the first forward pass (weights before the update) for either
    optimizer, so the two statistics are comparable between SGD and SAM.

    Args:
        model: Network to train.
        optimizer: A regular optimizer or a ``SAM`` instance.
        loader: Iterable of ``(images, labels)`` batches.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy_percent)``; ``(0.0, 0.0)`` for an empty loader.
    """
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    for imgs, lbls in loader:
        imgs, lbls = imgs.to(device), lbls.to(device)

        # First (or only) forward/backward pass at the current weights.
        optimizer.zero_grad()
        out = model(imgs)
        loss = criterion(out, lbls)
        loss.backward()

        if isinstance(optimizer, SAM):
            # Perturb, recompute the gradient there, then restore and update.
            optimizer.first_step(zero_grad=True)
            criterion(model(imgs), lbls).backward()
            optimizer.second_step(zero_grad=True)
        else:
            optimizer.step()

        # accumulate metrics
        batch_size = lbls.size(0)
        total_loss += loss.item() * batch_size
        correct += (out.argmax(1) == lbls).sum().item()
        total += batch_size

    if total == 0:
        # Nothing was iterated; avoid dividing by zero.
        return 0.0, 0.0
    return total_loss/total, 100 * correct/total


def validate(model, loader, device):
    """Compute the top-1 accuracy of ``model`` on ``loader``.

    Puts ``model`` into eval mode and runs without gradient tracking.

    Args:
        model: Network mapping a batch of inputs to per-class scores.
        loader: Iterable of ``(images, labels)`` batches.
        device: Device the batches are moved to.

    Returns:
        Accuracy in percent (0-100); ``0.0`` when the loader is empty.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for imgs, lbls in loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            out = model(imgs)
            correct += (out.argmax(1) == lbls).sum().item()
            total += lbls.size(0)
    if total == 0:
        return 0.0
    # Same percent scale as the accuracy returned by train_one_epoch.
    return 100 * correct / total

In [None]:
import torch, torch.nn as nn, torch.optim as optim
import numpy as np, random
from collections import deque

class DQNAgent:
    """DQN agent with uniform replay and a periodically synced target network.

    Args:
        state_dim: Size of the state vector.
        action_dim: Number of discrete actions.
        gamma: Discount factor.
        epsilon: Initial exploration probability.
        epsilon_min: Lower bound for the exploration probability.
        epsilon_decay: Factor ``epsilon`` is multiplied by after each learning
            step.
    """

    def __init__(self, state_dim, action_dim,
                 gamma=0.99, epsilon=1.0,
                 epsilon_min=0.01, epsilon_decay=0.995):
        self.state_dim     = state_dim
        self.action_dim    = action_dim
        self.gamma         = gamma
        self.epsilon       = epsilon
        self.epsilon_min   = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.memory        = deque(maxlen=10_000)

        # The target network must own its parameters. Wrapping the policy
        # network's layer objects in a second Sequential would share them, so
        # a fresh network is built and then initialized from the policy net.
        self.policy_net = self._build_network(state_dim, action_dim)
        self.target_net = self._build_network(state_dim, action_dim)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.optimizer   = optim.Adam(self.policy_net.parameters(), lr=1e-3)
        self.learn_step  = 0

    @staticmethod
    def _build_network(state_dim, action_dim):
        """Create a fresh MLP mapping a state to one Q-value per action."""
        return nn.Sequential(
            nn.Linear(state_dim, 256), nn.ReLU(),
            nn.Linear(256, 128),       nn.ReLU(),
            nn.Linear(128, action_dim)
        )

    # ---- epsilon-greedy action ---------------------------------------------
    def act(self, state):
        """Return a random action with probability epsilon, else the greedy one."""
        if random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)
        state = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            q = self.policy_net(state)
        return q.argmax(1).item()

    # ---- store experience ---------------------------------------------------
    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay memory."""
        self.memory.append((state, action, reward, next_state, done))

    # ---- learn --------------------------------------------------------------
    def learn(self, batch_size=32):
        """Run one TD-learning step on a random minibatch.

        Args:
            batch_size: Number of transitions sampled from memory.

        Returns:
            The MSE loss as a float, or ``0.0`` when the memory holds fewer
            than ``batch_size`` transitions (no update is made then).
        """
        if len(self.memory) < batch_size:
            return 0.0

        batch      = random.sample(self.memory, batch_size)
        states, actions, rewards, next_states, dones = map(np.array, zip(*batch))

        states      = torch.as_tensor(states,      dtype=torch.float32)
        next_states = torch.as_tensor(next_states, dtype=torch.float32)
        actions     = torch.as_tensor(actions,     dtype=torch.long)
        rewards     = torch.as_tensor(rewards,     dtype=torch.float32)
        dones       = torch.as_tensor(dones,       dtype=torch.float32)

        # Q(s, a) of the actions that were actually taken
        q_pred = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            # Bootstrapped target from the target network; masked when done
            q_next = self.target_net(next_states).max(1)[0]
            q_trg  = rewards + self.gamma * q_next * (1 - dones)

        loss = nn.functional.mse_loss(q_pred, q_trg)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Sync the target network every 100 learning steps (including step 0)
        if self.learn_step % 100 == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
        self.learn_step += 1
        return loss.item()


In [None]:
class Classifier(nn.Module):
    """Single linear layer mapping feature vectors to class scores.

    Args:
        feature_dim: Width of the incoming feature vectors.
        num_classes: Number of output scores per sample.
    """

    def __init__(self, feature_dim: int, num_classes: int):
        super().__init__()
        self.fc = nn.Linear(in_features=feature_dim, out_features=num_classes)

    def forward(self, x):
        """Apply the linear head to ``x`` and return the scores."""
        logits = self.fc(x)
        return logits

In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

def prepare_data(
    train_batch_size: int = 128,
    val_batch_size: int = 256,
    num_workers: int = 2,
    pin_memory: bool = True
):
    """
    Download CIFAR-10, attach the transforms and return (train_loader, val_loader).

    train_batch_size: batch size used for training (default 128)
    val_batch_size: batch size used for validation/test (default 256)
    num_workers: number of DataLoader worker processes
    pin_memory: forwarded to both DataLoaders as their pin_memory flag
    """
    norm_mean = (0.485, 0.456, 0.406)
    norm_std = (0.229, 0.224, 0.225)

    def build_transform(augment: bool):
        # Resize -> (optional horizontal flip) -> tensor -> normalize
        steps = [transforms.Resize(224)]
        if augment:
            steps.append(transforms.RandomHorizontalFlip())
        steps.append(transforms.ToTensor())
        steps.append(transforms.Normalize(norm_mean, norm_std))
        return transforms.Compose(steps)

    # Datasets: the train split is augmented, the test split is not
    train_set = datasets.CIFAR10(
        root="./data", train=True, download=True,
        transform=build_transform(augment=True)
    )
    val_set = datasets.CIFAR10(
        root="./data", train=False, download=True,
        transform=build_transform(augment=False)
    )

    # Options common to both loaders
    shared_opts = dict(num_workers=num_workers, pin_memory=pin_memory)

    train_loader = DataLoader(
        train_set, batch_size=train_batch_size, shuffle=True, **shared_opts
    )
    val_loader = DataLoader(
        val_set, batch_size=val_batch_size, shuffle=False, **shared_opts
    )

    return train_loader, val_loader


In [None]:
from __future__ import annotations
import time
import torch
import torch.nn as nn
import torch.optim as optim

# Assumes the following are already defined in earlier cells:
# from feature_extractor import FeatureExtractor
# from classifier       import Classifier
# from dqn_agent        import DQNAgent

# Import DataLoader and the dataset directly
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Resize

# Time-penalty coefficient applied to each batch's reward
W_TIME = 0.05

def main():
    """Train on CIFAR-10 while a DQN agent picks SGD or SAM for every batch.

    Per batch the agent observes the feature vector of the first sample,
    chooses an optimizer (0: SGD, 1: SAM), and is rewarded with
    ``-loss - W_TIME * batch_time``. After each epoch the accuracy on the
    CIFAR-10 test split is measured, and the best weights are written to
    ``best_model.pth``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # -------- 1. Models ----------
    feature_extractor = FeatureExtractor(pretrained=False).to(device)
    classifier        = Classifier(feature_dim=512, num_classes=10).to(device)
    params = list(feature_extractor.parameters()) + list(classifier.parameters())

    # -------- 2. Two separate optimizers over the same parameters ----------
    opt_sgd = torch.optim.SGD(params, lr=0.1, momentum=0.9)
    opt_sam = SAM(params, torch.optim.SGD, lr=0.1, momentum=0.9)

    criterion = nn.CrossEntropyLoss()
    dqn_agent = DQNAgent(state_dim=512, action_dim=2)  # 0:SGD, 1:SAM

    # -------- 3. CIFAR-10 data --------
    transform = Compose([Resize((32, 32)), ToTensor()])
    train_dataset = CIFAR10(root="./data", train=True, download=True, transform=transform)
    val_dataset   = CIFAR10(root="./data", train=False, download=True, transform=transform)

    train_loader = DataLoader(
        train_dataset,
        batch_size=128,
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )
    test_loader = DataLoader(
        val_dataset,
        batch_size=256,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )

    def evaluate(loader):
        """Return the top-1 accuracy (%) of extractor + classifier on ``loader``."""
        # The notebook's validate() takes a single model and three arguments,
        # so the two-module evaluation is done here instead.
        feature_extractor.eval(); classifier.eval()
        correct, seen = 0, 0
        with torch.no_grad():
            for eval_inputs, eval_targets in loader:
                eval_inputs = eval_inputs.to(device)
                eval_targets = eval_targets.to(device)
                preds = classifier(feature_extractor(eval_inputs)).argmax(1)
                correct += (preds == eval_targets).sum().item()
                seen += eval_targets.size(0)
        return 100. * correct / max(seen, 1)

    NUM_EPOCHS, best_acc = 100, 0.0
    results = {k: [] for k in [
        "reward_history", "train_loss_history",
        "train_acc_history", "val_acc_history", "dqn_loss_history"
    ]}

    for epoch in range(NUM_EPOCHS):
        # ----- reset statistics -----
        epoch_start    = time.time()
        total_reward   = 0.0
        total_dqn_loss = 0.0
        train_loss_sum = 0.0
        train_correct  = 0
        train_seen     = 0
        # evaluate() leaves both modules in eval mode; switch back here.
        feature_extractor.train(); classifier.train()

        # Initial state and the name of the last chosen action
        state = None
        last_action_str = ""

        # Iterate over the batches
        for batch_idx, (inputs, targets) in enumerate(train_loader, start=1):
            batch_start = time.time()

            inputs, targets = inputs.to(device), targets.to(device)
            features = feature_extractor(inputs)  # [B,512]
            outputs  = classifier(features)
            loss     = criterion(outputs, targets)

            # ---- Build the initial state if there is none yet ----
            if state is None:
                state = features[0].detach().cpu().numpy()

            # =============== DQN picks the optimizer ===============
            action = dqn_agent.act(state)  # 0 -> SGD, 1 -> SAM
            last_action_str = "SGD" if action == 0 else "SAM"
            print(
                f"Epoch {epoch+1:03}/{NUM_EPOCHS} - "
                f"Batch {batch_idx:03}/{len(train_loader)} - "
                f"Action: {last_action_str}"
            )

            # Apply the chosen optimizer
            if action == 0:
                opt_sgd.zero_grad()
                loss.backward()
                opt_sgd.step()
            else:
                # A preceding SGD step leaves its gradients in place; clear
                # them so SAM does not accumulate on top of stale values.
                opt_sam.zero_grad()
                loss.backward()
                opt_sam.first_step(zero_grad=True)
                outputs2 = classifier(feature_extractor(inputs))
                criterion(outputs2, targets).backward()
                opt_sam.second_step(zero_grad=True)

            # Elapsed time and reward
            batch_elapsed = time.time() - batch_start
            reward        = -loss.item() - W_TIME * batch_elapsed

            # Update the DQN; the next state is only an observation, so no
            # autograd graph is needed for this forward pass.
            with torch.no_grad():
                next_state = feature_extractor(inputs)[0].cpu().numpy()
            dqn_agent.remember(state, action, reward, next_state, False)
            dqn_loss = dqn_agent.learn()

            # Advance the state
            state = next_state

            # ----- accumulate -----
            total_reward   += reward
            total_dqn_loss += dqn_loss
            train_loss_sum += loss.item() * targets.size(0)
            train_correct  += (outputs.argmax(1) == targets).sum().item()
            train_seen     += targets.size(0)

        # ----- end-of-epoch statistics -----
        train_loss = train_loss_sum / train_seen
        train_acc  = 100. * train_correct / train_seen
        val_acc    = evaluate(test_loader)
        epoch_elapsed = time.time() - epoch_start

        # Store the results
        results["reward_history"].append(total_reward)
        results["dqn_loss_history"].append(total_dqn_loss / max(len(train_loader), 1))
        results["train_loss_history"].append(train_loss)
        results["train_acc_history"].append(train_acc)
        results["val_acc_history"].append(val_acc)

        # Print a summary including the last optimizer used
        print(
            f"Epoch {epoch+1:03}/{NUM_EPOCHS} | "
            f"LastOpt {last_action_str} | "
            f"Reward {total_reward:8.1f} | "
            f"TrainLoss {train_loss:.4f} | "
            f"TrainAcc {train_acc:6.2f}% | "
            f"ValAcc {val_acc:6.2f}% | "
            f"DQN_Loss {results['dqn_loss_history'][-1]:.4f} | "
            f"Time {epoch_elapsed:.1f}s"
        )

        # Save the best model so far
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save({
                "feature_extractor": feature_extractor.state_dict(),
                "classifier": classifier.state_dict()
            }, "best_model.pth")

if __name__ == "__main__":
    main()


In [None]:
from __future__ import annotations
import time
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# Assumes the following are already defined in earlier cells:
# from feature_extractor import FeatureExtractor
# from classifier       import Classifier
# from dqn_agent        import DQNAgent
# from data_utils       import validate, SAM

# Import DataLoader and the dataset directly
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Resize

# Import confusion matrix tools
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Time-penalty coefficient applied to each batch's reward
W_TIME = 0.05


def main():
    """Train on CIFAR-10 with a DQN-selected optimizer and return the artefacts.

    Per batch the agent observes the feature vector of the first sample,
    chooses an optimizer (0: SGD, 1: SAM), and is rewarded with
    ``-loss - W_TIME * batch_time``. The history is pickled to
    ``results.pkl``, the best weights go to ``best_model.pth`` and the final
    weights to ``feat.pth`` / ``clf.pth``.

    Returns:
        ``(results, feature_extractor, classifier, device, test_loader)``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # -------- 1. Models ----------
    feature_extractor = FeatureExtractor(pretrained=False).to(device)
    classifier        = Classifier(feature_dim=512, num_classes=10).to(device)
    params = list(feature_extractor.parameters()) + list(classifier.parameters())

    # -------- 2. Two separate optimizers over the same parameters ----------
    opt_sgd = torch.optim.SGD(params, lr=0.1, momentum=0.9)
    opt_sam = SAM(params, torch.optim.SGD, lr=0.1, momentum=0.9)

    criterion = nn.CrossEntropyLoss()
    dqn_agent = DQNAgent(state_dim=512, action_dim=2)  # 0:SGD, 1:SAM

    # -------- 3. CIFAR-10 data --------
    transform = Compose([Resize((32, 32)), ToTensor()])
    train_dataset = CIFAR10(root="./data", train=True, download=True, transform=transform)
    val_dataset   = CIFAR10(root="./data", train=False, download=True, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2, pin_memory=True)
    test_loader  = DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=2, pin_memory=True)

    def evaluate(loader):
        """Return the top-1 accuracy (%) of extractor + classifier on ``loader``."""
        # The notebook's validate() takes a single model and three arguments,
        # so the two-module evaluation is done here instead.
        feature_extractor.eval(); classifier.eval()
        correct, seen = 0, 0
        with torch.no_grad():
            for eval_inputs, eval_targets in loader:
                eval_inputs = eval_inputs.to(device)
                eval_targets = eval_targets.to(device)
                preds = classifier(feature_extractor(eval_inputs)).argmax(1)
                correct += (preds == eval_targets).sum().item()
                seen += eval_targets.size(0)
        return 100. * correct / max(seen, 1)

    NUM_EPOCHS, best_acc = 100, 0.0
    results = {
        "reward_history": [],
        "train_loss_history": [],
        "train_acc_history": [],
        "val_acc_history": [],
        "dqn_loss_history": [],
        "time_history": []
    }

    # Train loop
    for epoch in range(NUM_EPOCHS):
        epoch_start    = time.time()
        total_reward   = 0.0
        total_dqn_loss = 0.0
        train_loss_sum = 0.0
        train_correct  = 0
        train_seen     = 0
        # evaluate() leaves both modules in eval mode; switch back here.
        feature_extractor.train(); classifier.train()

        state = None
        last_action_str = ""

        for batch_idx, (inputs, targets) in enumerate(train_loader, start=1):
            batch_start = time.time()
            inputs, targets = inputs.to(device), targets.to(device)

            features = feature_extractor(inputs)
            outputs  = classifier(features)
            loss     = criterion(outputs, targets)

            if state is None:
                state = features[0].detach().cpu().numpy()

            # Pick the optimizer and perform the update
            action = dqn_agent.act(state)
            last_action_str = "SGD" if action == 0 else "SAM"
            if action == 0:
                opt_sgd.zero_grad(); loss.backward(); opt_sgd.step()
            else:
                # A preceding SGD step leaves its gradients in place; clear
                # them so SAM does not accumulate on top of stale values.
                opt_sam.zero_grad()
                loss.backward(); opt_sam.first_step(zero_grad=True)
                outputs2 = classifier(feature_extractor(inputs))
                criterion(outputs2, targets).backward(); opt_sam.second_step(zero_grad=True)

            batch_elapsed = time.time() - batch_start
            reward        = -loss.item() - W_TIME * batch_elapsed

            # The next state is only an observation, so no autograd graph is
            # needed for this forward pass.
            with torch.no_grad():
                next_state = feature_extractor(inputs)[0].cpu().numpy()
            dqn_agent.remember(state, action, reward, next_state, False)
            dqn_loss = dqn_agent.learn()
            state = next_state

            total_reward   += reward
            total_dqn_loss += dqn_loss
            train_loss_sum += loss.item() * targets.size(0)
            train_correct  += (outputs.argmax(1) == targets).sum().item()
            train_seen     += targets.size(0)

        train_loss = train_loss_sum / train_seen
        train_acc  = 100. * train_correct / train_seen
        val_acc    = evaluate(test_loader)
        epoch_elapsed = time.time() - epoch_start

        # Store the history
        results["reward_history"].append(total_reward)
        results["train_loss_history"].append(train_loss)
        results["train_acc_history"].append(train_acc)
        results["val_acc_history"].append(val_acc)
        results["dqn_loss_history"].append(total_dqn_loss / max(len(train_loader), 1))
        results["time_history"].append(epoch_elapsed)

        # Print a summary
        print(f"Epoch {epoch+1:03}/{NUM_EPOCHS} | LastOpt {last_action_str} | "
              f"Reward {total_reward:.1f} | TrainLoss {train_loss:.4f} | "
              f"TrainAcc {train_acc:.2f}% | ValAcc {val_acc:.2f}% | "
              f"DQN_Loss {results['dqn_loss_history'][-1]:.4f} | Time {epoch_elapsed:.1f}s")

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save({
                "feature_extractor": feature_extractor.state_dict(),
                "classifier": classifier.state_dict()
            }, "best_model.pth")

    # ---- After training, write the results to disk ----
    import pickle
    with open("results.pkl", "wb") as f:
        pickle.dump(results, f)
    torch.save(feature_extractor.state_dict(), "feat.pth")
    torch.save(classifier.state_dict(),       "clf.pth")

    return results, feature_extractor, classifier, device, test_loader


# ==== Code cell: Plotting & Confusion Matrix (reuse variables) ====
# Make sure the variables exist; if they do not, call main() to train and
# collect them.
try:
    results
except NameError:
    results, feature_extractor, classifier, device, test_loader = main()

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import torch

# 1) Summary plots
epochs = range(1, len(results["reward_history"]) + 1)
plt.figure(figsize=(14, 10))

plt.subplot(2, 2, 1)
plt.plot(epochs, results["reward_history"], label="Reward per Epoch")
plt.xlabel("Epoch"); plt.ylabel("Reward"); plt.title("Reward per Epoch"); plt.legend()

plt.subplot(2, 2, 2)
plt.plot(epochs, results["train_loss_history"], label="Train Loss per Epoch")
plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.title("Train Loss per Epoch"); plt.legend()

plt.subplot(2, 2, 3)
plt.plot(epochs, results["train_acc_history"], label="Train Accuracy per Epoch")
plt.xlabel("Epoch"); plt.ylabel("Accuracy (%)"); plt.title("Train Accuracy per Epoch"); plt.legend()

# time_history holds the duration of each epoch; accumulate it so the x-axis
# is the elapsed training time rather than an unordered list of durations.
cumulative_time, elapsed = [], 0.0
for epoch_time in results["time_history"]:
    elapsed += epoch_time
    cumulative_time.append(elapsed)

plt.subplot(2, 2, 4)
plt.plot(cumulative_time, results["dqn_loss_history"], label="DQN Loss vs Time")
plt.xlabel("Time (s)"); plt.ylabel("DQN Loss"); plt.title("DQN Loss vs Time"); plt.legend()

plt.tight_layout()
plt.show()

# 2) Confusion matrix on the test loader
feature_extractor.eval()
classifier.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = classifier(feature_extractor(inputs))
        all_preds.append(outputs.argmax(dim=1).cpu())
        all_labels.append(labels.cpu())
all_preds  = torch.cat(all_preds).numpy()
all_labels = torch.cat(all_labels).numpy()
cm = confusion_matrix(all_labels, all_preds)
disp = ConfusionMatrixDisplay(cm)
disp.plot(cmap=plt.cm.Blues)
plt.title("Confusion Matrix")
plt.show()


In [None]:
# 1) Summary plots (re-run of the plotting step; expects results,
#    feature_extractor, classifier, device and test_loader to exist already)
epochs = range(1, len(results["reward_history"]) + 1)
plt.figure(figsize=(14, 10))

plt.subplot(2, 2, 1)
plt.plot(epochs, results["reward_history"], label="Reward per Epoch")
plt.xlabel("Epoch"); plt.ylabel("Reward"); plt.title("Reward per Epoch"); plt.legend()

plt.subplot(2, 2, 2)
plt.plot(epochs, results["train_loss_history"], label="Train Loss per Epoch")
plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.title("Train Loss per Epoch"); plt.legend()

plt.subplot(2, 2, 3)
plt.plot(epochs, results["train_acc_history"], label="Train Accuracy per Epoch")
plt.xlabel("Epoch"); plt.ylabel("Accuracy (%)"); plt.title("Train Accuracy per Epoch"); plt.legend()

# time_history holds the duration of each epoch; accumulate it so the x-axis
# is the elapsed training time rather than an unordered list of durations.
cumulative_time, elapsed = [], 0.0
for epoch_time in results["time_history"]:
    elapsed += epoch_time
    cumulative_time.append(elapsed)

plt.subplot(2, 2, 4)
plt.plot(cumulative_time, results["dqn_loss_history"], label="DQN Loss vs Time")
plt.xlabel("Time (s)"); plt.ylabel("DQN Loss"); plt.title("DQN Loss vs Time"); plt.legend()

plt.tight_layout()
plt.show()

# 2) Confusion matrix on the test loader
feature_extractor.eval()
classifier.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = classifier(feature_extractor(inputs))
        all_preds.append(outputs.argmax(dim=1).cpu())
        all_labels.append(labels.cpu())
all_preds  = torch.cat(all_preds).numpy()
all_labels = torch.cat(all_labels).numpy()
cm = confusion_matrix(all_labels, all_preds)
disp = ConfusionMatrixDisplay(cm)
disp.plot(cmap=plt.cm.Blues)
plt.title("Confusion Matrix")
plt.show()
