In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.utils.data import DataLoader
from dataloader_pairs import ARC_Dataset
from torch import optim
import torch.nn.functional as F
from tqdm import tqdm
from copy import deepcopy
import matplotlib.pyplot as plt
import gc
import random
import random
import numpy as np
import torch

# Seed every random number generator in use so that runs are repeatable
def fix_random_seeds(seed_value=42):
    """Seed the Python, NumPy and PyTorch RNGs and set the cuDNN flags.

    Args:
        seed_value: Seed handed unchanged to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed_value)

    # cuDNN: deterministic mode on, benchmark mode off.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

fix_random_seeds(77)

# Pick the compute backend: MPS or CPU first, then CUDA overrides when present
if torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f'Using {device} device')


# Total number of classes
num_classes = 11

# Weights given to classes 0 and 1 (each kept at 10% or less)
weight_0 = 0.04
weight_1 = 0.05

# Share the remaining weight equally among the other nine classes
remaining_weight = 1.0 - (weight_0 + weight_1)
weight_other = remaining_weight / (num_classes - 2)

# Build the per-class weight list
class_weights = [weight_0, weight_1]
class_weights.extend(weight_other for _ in range(num_classes - 2))

class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32, device=device)

def criterion(y_pred, y):
    """Class-weighted cross-entropy loss.

    Args:
        y_pred: Predictions passed unchanged to ``F.cross_entropy``.
        y: Labels; cast to ``long`` and squeezed along axis 1 (the axis is
            removed only when it has size 1) before the loss is computed.

    Returns:
        The result of ``F.cross_entropy`` using ``class_weights_tensor`` as
        the per-class weight.
    """
    targets = y.long()
    targets = targets.squeeze(1)
    # Unweighted alternative: F.cross_entropy(y_pred, targets)
    return F.cross_entropy(y_pred, targets, weight=class_weights_tensor)

# CBAM module definitions
class ChannelAttention(nn.Module):
    """Channel-attention gate built from pooled descriptors.

    The input is reduced to one value per channel by both average and max
    pooling; each descriptor goes through the same two 1x1 convolutions
    (with a ReLU in between), the two results are summed and passed through
    a sigmoid.

    Args:
        in_planes: Number of channels of the input feature map.
        ratio: Reduction factor for the bottleneck between the two 1x1
            convolutions.
    """

    def __init__(self, in_planes, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Integer division gives 0 when in_planes < ratio, which would build
        # a zero-channel bottleneck; keep at least one hidden channel.
        hidden_planes = max(1, in_planes // ratio)
        self.fc1 = nn.Conv2d(in_planes, hidden_planes, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(hidden_planes, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-channel gates in (0, 1) with spatial size 1x1."""
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    """Spatial-attention gate computed from channel-wise statistics.

    Args:
        kernel_size: Size of the single convolution; padding is
            ``kernel_size // 2``.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        half = kernel_size // 2
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=half, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a one-channel sigmoid map derived from the mean and max over axis 1."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        # Stack the two statistics as a 2-channel map for the convolution.
        stacked = torch.cat((channel_mean, channel_max), dim=1)
        return self.sigmoid(self.conv1(stacked))

class CBAM(nn.Module):
    """Channel attention followed by spatial attention, each applied as a multiplicative gate.

    Args:
        channels: Channel count forwarded to ``ChannelAttention``.
        ratio: Reduction ratio forwarded to ``ChannelAttention``.
        kernel_size: Kernel size forwarded to ``SpatialAttention``.
    """

    def __init__(self, channels, ratio=16, kernel_size=7):
        super().__init__()
        self.ca = ChannelAttention(channels, ratio)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        """Scale ``x`` by the channel gate, then scale that result by its spatial gate."""
        channel_gated = x * self.ca(x)
        # The spatial gate is computed from the already channel-gated tensor.
        return channel_gated * self.sa(channel_gated)

class ResBlock(nn.Module):
    """Residual block: two (BatchNorm -> ReLU -> 3x3 conv) stages plus a skip connection.

    Args:
        C: Channel count, identical for input and output.
        dropout_prob: Dropout probability applied between the two stages.
    """

    def __init__(self, C: int, dropout_prob: float):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.bnorm1 = nn.BatchNorm2d(C)
        self.bnorm2 = nn.BatchNorm2d(C)
        conv_kwargs = {"kernel_size": 3, "padding": 1}
        self.conv1 = nn.Conv2d(C, C, **conv_kwargs)
        self.conv2 = nn.Conv2d(C, C, **conv_kwargs)
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, x):
        """Return ``x`` plus the output of the two-stage branch."""
        branch = self.bnorm1(x)
        branch = self.conv1(self.relu(branch))
        branch = self.dropout(branch)
        branch = self.bnorm2(branch)
        branch = self.conv2(self.relu(branch))
        return branch + x

class ConvBlock(nn.Module):
    """Convolution -> BatchNorm -> ReLU -> Dropout.

    Args:
        mode: ``"down"`` (4x4 conv, stride 2), ``"up"`` (4x4 transposed conv,
            stride 2) or ``"same"`` (3x3 conv, padding 1).
        C_in: Input channel count.
        C_out: Output channel count.
        dropout_prob: Dropout probability.

    Raises:
        ValueError: If ``mode`` is not one of the three supported values.
    """

    def __init__(self, mode: str, C_in: int, C_out: int, dropout_prob: float):
        super().__init__()
        self.relu = nn.ReLU()
        self.bnorm = nn.BatchNorm2d(C_out)

        # Builders are lazy so only the selected layer is ever constructed.
        conv_builders = (
            ("down", lambda: nn.Conv2d(C_in, C_out, kernel_size=4, stride=2, padding=0)),
            ("up", lambda: nn.ConvTranspose2d(C_in, C_out, kernel_size=4, stride=2, padding=0)),
            ("same", lambda: nn.Conv2d(C_in, C_out, kernel_size=3, padding=1)),
        )
        for name, build in conv_builders:
            if mode == name:
                self.conv = build()
                break
        else:
            raise ValueError("Wrong ConvBlock mode.")

        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, z):
        """Apply the convolution, batch norm, ReLU and dropout in that order."""
        return self.dropout(self.relu(self.bnorm(self.conv(z))))

class Encoder(nn.Module):
    """Three same-resolution conv stages over a one-hot encoded grid.

    Args:
        channels: Output channel counts of the three ``ConvBlock`` stages.
        dropout: Dropout probability forwarded to every sub-block.
    """

    def __init__(self, channels=[256, 512, 512], dropout=0.1):
        super().__init__()
        c0, c1, c2 = channels[0], channels[1], channels[2]
        self.conv1 = ConvBlock("same", 11, c0, dropout)
        self.res12 = ResBlock(c0, dropout)
        self.conv2 = ConvBlock("same", c0, c1, dropout)
        self.res23 = ResBlock(c1, dropout)
        self.conv3 = ConvBlock("same", c1, c2, dropout)

    def one_hot_encode(self, z, num_classes=11):
        """
        One-hot encode the input tensor and move the class axis to position 1.
        """
        labels = z.squeeze(1).long()
        encoded = F.one_hot(labels, num_classes=num_classes)
        return encoded.permute(0, 3, 1, 2).float()

    def forward(self, z):
        """Return the final feature map and the list of the three stage outputs.

        The last list entry is the same tensor as the returned feature map.
        """
        x = self.res12(self.conv1(self.one_hot_encode(z)))
        stage0 = x
        x = self.res23(self.conv2(x))
        stage1 = x
        x = self.conv3(x)
        return x, [stage0, stage1, x]
    
class Decoder(nn.Module):
    """Decoder that concatenates encoder stage outputs before each conv stage.

    Args:
        channels: The same channel list given to the encoder; read from the
            end (``channels[-1]``, ``channels[-2]``, ``channels[-3]``).
        dropout: Dropout probability forwarded to every sub-block.
    """

    def __init__(self, channels=[256, 512, 512], dropout=0.1):
        super().__init__()
        c_low, c_mid, c_top = channels[-3], channels[-2], channels[-1]
        # Every conv stage sees twice the channels because of the concatenation.
        self.conv3 = ConvBlock("same", c_top * 2, c_mid, dropout)
        self.res32 = ResBlock(c_mid, dropout)
        self.conv2 = ConvBlock("same", c_mid * 2, c_low, dropout)
        self.res21 = ResBlock(c_low, dropout)
        self.conv1 = ConvBlock("same", c_low * 2, c_low, dropout)
        self.conv0 = nn.Conv2d(c_low, 11, kernel_size=3, padding=1)

    def forward(self, x, residuals):
        """Fuse ``x`` with ``residuals[2]``, ``[1]``, ``[0]`` in turn and return an 11-channel map."""
        x = self.res32(self.conv3(torch.cat((x, residuals[2]), dim=1)))
        x = self.res21(self.conv2(torch.cat((x, residuals[1]), dim=1)))
        x = self.conv1(torch.cat((x, residuals[0]), dim=1))
        return self.conv0(x)

class CBAMAEmodel(nn.Module):
    """Encoder/decoder model that conditions a task input on one example pair.

    All three inputs go through the same encoder. The example input and
    output embeddings are concatenated, gated by CBAM and reduced with a 1x1
    convolution; the result is concatenated with the task-input embedding,
    gated and reduced again, then decoded together with the task input's
    encoder stage outputs.

    Args:
        channels: Channel counts for the encoder/decoder stages.
        embed_dim: Channel count of the encoder output; must equal
            ``channels[-1]``.
        dropout: Dropout probability forwarded to encoder and decoder.

    Raises:
        ValueError: If ``embed_dim`` differs from ``channels[-1]``.
    """

    def __init__(self, channels=[256, 512, 512], embed_dim=512, dropout=0.1):
        super().__init__()
        # The CBAM / 1x1-conv layers are sized from embed_dim while the tensors
        # they receive are sized from channels[-1]; a mismatch can only fail
        # later inside forward() with an opaque conv shape error, so fail here.
        if embed_dim != channels[-1]:
            raise ValueError(
                f"embed_dim ({embed_dim}) must equal channels[-1] ({channels[-1]})"
            )
        self.preprocess_and_embed = Encoder(channels, dropout)
        self.decoder = Decoder(channels, dropout)
        self.cbam1 = CBAM(channels=embed_dim*2)
        self.cbam2 = CBAM(channels=embed_dim*2)
        self.reduce_channels = nn.Conv2d(embed_dim*2, embed_dim*1, kernel_size=1)
        self.reduce_channels2 = nn.Conv2d(embed_dim*2, embed_dim*1, kernel_size=1)

    def forward(self, edu_input, edu_output, task_input):
        """Predict the 11-channel output map for ``task_input``.

        Args:
            edu_input: Example input grid.
            edu_output: Example output grid.
            task_input: Grid to solve.

        Returns:
            The decoder output for ``task_input``.
        """
        # Only the task input's stage outputs are kept for the decoder.
        embedded_edu_input, _ = self.preprocess_and_embed(edu_input)
        embedded_edu_output, _ = self.preprocess_and_embed(edu_output)
        embedded_task_input, residuals = self.preprocess_and_embed(task_input)

        # Fuse the example pair along the channel axis.
        combined_edu = torch.cat([embedded_edu_input, embedded_edu_output], dim=1)
        attended_edu = self.cbam1(combined_edu)
        attended_edu = self.reduce_channels(attended_edu)

        # Fuse the example summary with the task embedding.
        task_edu_cat = torch.cat([attended_edu, embedded_task_input], dim=1)
        task_causal_mix = self.cbam2(task_edu_cat)
        task_causal_mix = self.reduce_channels2(task_causal_mix)

        final_output = self.decoder(task_causal_mix, residuals)
        return final_output

# Eleven hex colours, one per list index, wrapped in a ListedColormap
colors = [
    '#000000',
    '#1E93FF',
    '#F93C31',
    '#4FCC30',
    '#FFDC00',
    '#999999',
    '#E53AA3',
    '#FF851B',
    '#87D8F1',
    '#921231',
    '#555555',
]
colormap = plt.matplotlib.colors.ListedColormap(colors)

def show_grid_side_by_side(*grids):
    """Draw every grid in its own panel of a single-row figure and show it.

    Args:
        *grids: Grids exposing ``ndim``. A 4-D grid is squeezed and a 3-D
            grid is replaced by its first slice before plotting.
    """
    panel_count = len(grids)
    # squeeze=False always yields a 2-D array of axes, so the single-panel
    # case needs no special handling.
    _, axes_grid = plt.subplots(
        1, panel_count, figsize=(panel_count * 2.8, 2.8), squeeze=False
    )

    for axis, cells in zip(axes_grid[0], grids):
        dims = cells.ndim
        if dims == 4:
            cells = cells.squeeze()  # [1, 1, 30, 30] -> [30, 30]
        elif dims == 3:
            cells = cells[0]  # [1, 30, 30] -> [30, 30]

        axis.pcolormesh(
            cells,
            edgecolors=colors[-1],
            linewidth=0.5,
            cmap=colormap,
            vmin=0,
            vmax=10,
        )
        axis.invert_yaxis()
        axis.set_aspect('equal')
        axis.axis('off')

    plt.show()

# Locations of the ARC challenge / solution JSON files
train_challenge = './kaggle/input/arc-prize-2024/arc-agi_training_challenges.json'
train_solution = "./kaggle/input/arc-prize-2024/arc-agi_training_solutions.json"
eval_challenge = "./kaggle/input/arc-prize-2024/arc-agi_evaluation_challenges.json"
eval_solution = "./kaggle/input/arc-prize-2024/arc-agi_evaluation_solutions.json"

# Run configuration
kwargs = {
    'epochs': 100,
    'task_numbers': 16,  # equal to the number of tasks; used as the DataLoader batch size
    'task_data_num': 1,
    'example_data_num': 5,  # equal to inner model batch size
    'inner_lr': 0.001,
    'outer_lr': 0.001,
    'embed_size': 1,
}

# NOTE(review): a recorded run of this script stopped on the first training
# batch with "stack expects each tensor to be equal size, but got
# [4, 1, 30, 30] at entry 0 and [6, 1, 30, 30] at entry 1", i.e. the default
# DataLoader collation cannot stack samples whose leading dimension differs.
# ARC_Dataset is not visible here; it presumably needs fixed-size samples or a
# custom collate_fn -- TODO confirm and fix at the dataset level.
train_dataset = ARC_Dataset(train_challenge, train_solution)
train_loader = DataLoader(train_dataset, batch_size=kwargs['task_numbers'], shuffle=True)

# Validate on the evaluation split; this previously re-used the training
# files, so the validation metrics were measured on the training data.
eval_dataset = ARC_Dataset(eval_challenge, eval_solution)
eval_loader = DataLoader(eval_dataset, batch_size=kwargs['task_numbers'], shuffle=False)

model = CBAMAEmodel(channels=[256, 512, 512], embed_dim=512)
optimizer = optim.AdamW(model.parameters(), lr=kwargs['outer_lr'])

def train_and_validate(model, train_loader, val_loader, criterion, optimizer, device, epochs):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Each batch from either loader must unpack into
    ``(input_tensor, output_tensor, example_input, example_output)``; the
    model is called as ``model(example_input, example_output, input_tensor)``
    and the loss as ``criterion(output, output_tensor)``.

    After every epoch the sample-weighted mean training and validation
    losses and the validation accuracy are printed. Accuracy is the share of
    target elements whose arg-max prediction over axis 1 matches the target.
    Whenever the validation loss improves, the model's ``state_dict`` is
    written to ``best_model.pth``.

    Args:
        model: Module to train; moved to ``device``.
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches.
        criterion: Callable ``(prediction, target) -> loss``.
        optimizer: Optimizer over the model's parameters.
        device: Device every batch tensor is moved to.
        epochs: Number of epochs to run.
    """
    model.to(device)
    best_val_loss = float('inf')  # start at infinity so the first epoch always saves
    for epoch in range(epochs):
        print(f'Epoch {epoch+1}/{epochs}')

        # Training phase
        model.train()
        total_train_loss = 0.0
        train_samples = 0
        for data in tqdm(train_loader, desc='Training'):
            input_tensor, output_tensor, example_input, example_output = [d.to(device) for d in data]

            optimizer.zero_grad()
            output = model(example_input, example_output, input_tensor)
            loss = criterion(output, output_tensor)
            loss.backward()
            optimizer.step()

            # Weight each batch loss by its sample count so the epoch mean is per sample.
            batch_size = input_tensor.size(0)
            total_train_loss += loss.item() * batch_size
            train_samples += batch_size

            del input_tensor, output_tensor, example_input, example_output
            gc.collect()
            torch.cuda.empty_cache()

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        avg_train_loss = total_train_loss / max(train_samples, 1)
        print(f'Training Loss: {avg_train_loss:.4f}')

        # Validation phase
        model.eval()
        total_val_loss = 0.0
        correct = 0
        total = 0
        total_elements = 0
        with torch.no_grad():
            for data in tqdm(val_loader, desc='Validation'):
                input_tensor, output_tensor, example_input, example_output = [d.to(device) for d in data]

                output = model(example_input, example_output, input_tensor)
                loss = criterion(output, output_tensor)

                batch_size = input_tensor.size(0)
                total_val_loss += loss.item() * batch_size
                total += batch_size

                # Assuming classification task: predicted class = arg-max over axis 1.
                _, predicted = torch.max(output, 1)
                target = output_tensor.long()
                # The target may carry a singleton axis 1 that the prediction
                # lacks; comparing without removing it would broadcast across
                # samples and count wrong matches.
                if target.dim() == predicted.dim() + 1 and target.size(1) == 1:
                    target = target.squeeze(1)
                correct += (predicted == target).sum().item()
                total_elements += target.numel()

                del input_tensor, output_tensor, example_input, example_output
                gc.collect()
                torch.cuda.empty_cache()

        avg_val_loss = total_val_loss / max(total, 1)
        # Matches are counted per element, so the denominator is the element count.
        val_accuracy = 100 * correct / max(total_elements, 1)
        print(f'Validation Loss: {avg_val_loss:.4f}')
        print(f'Validation Accuracy: {val_accuracy:.2f}%')

        # Save the best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), 'best_model.pth')
            print('Best model saved')

# Start training; eval_loader is used as the validation loader and the epoch
# count comes from kwargs['epochs'].
train_and_validate(model, train_loader, eval_loader, criterion, optimizer, device, epochs=kwargs['epochs'])



Using cuda device
Epoch 1/100


Training:   0%|          | 0/26 [00:00<?, ?it/s]


RuntimeError: stack expects each tensor to be equal size, but got [4, 1, 30, 30] at entry 0 and [6, 1, 30, 30] at entry 1