In [2]:
channels    = [256, 512, 512]
latent_dim  = 512
hidden_size = 512

def count_parameters(model):
    """Count the trainable parameters in a model."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

model = Magician(channels, latent_dim, hidden_size)
model_size = count_parameters(model)
print(f"model {int(model_size/1e6)}M")
print("encoder size:", count_parameters(model.encoder))
print("decoder size:", count_parameters(model.decoder))
print("mlp_key size:", count_parameters(model.mlp_key))
print("mlp_map size:", count_parameters(model.mlp_map))

model 38M
encoder size: 13292288
decoder size: 21661963
mlp_key size: 2100736
mlp_map size: 1052160


In [3]:
from dataloader_pairs import ARC_Dataset
from torch.utils.data import DataLoader

train_challenge = './kaggle/input/arc-prize-2024/arc-agi_training_challenges.json'
train_solution = "./kaggle/input/arc-prize-2024/arc-agi_training_solutions.json"
eval_challenge = "./kaggle/input/arc-prize-2024/arc-agi_evaluation_challenges.json"
eval_solution = "./kaggle/input/arc-prize-2024/arc-agi_evaluation_solutions.json"

kwargs = {
    'epochs': 100,
    'task_numbers': 1, #equal to the number of tasks
    'task_data_num': 1,
    'example_data_num': 5, #equal to inner model batch size
    'inner_lr': 0.001,
    'outer_lr': 0.001,
    'embed_size': 1,
    
}


train_dataset = ARC_Dataset(train_challenge, train_solution)
train_loader = DataLoader(train_dataset, batch_size=kwargs['task_numbers'], shuffle=True)

eval_dataset = ARC_Dataset(train_challenge, train_solution)
eval_loader = DataLoader(eval_dataset, batch_size=kwargs['task_numbers'], shuffle=False)

ti, to, ei, eo = next(iter(train_loader))

print(ti.shape, to.shape, ei.shape, eo.shape)

torch.Size([1, 1, 1, 30, 30]) torch.Size([1, 1, 1, 30, 30]) torch.Size([1, 6, 1, 30, 30]) torch.Size([1, 6, 1, 30, 30])


In [4]:
import torch.optim as optim

# CUDA 사용 가능 여부 확인
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
device = 'cuda' if torch.cuda.is_available() else device  
print(f'Using {device} device')

model = model.to(device)

# 총 클래스의 수
num_classes = 11

# 0번과 1번 클래스에 부여할 가중치 (10% 이하로 설정)
weight_0 = 0.04
weight_1 = 0.05

# 나머지 9개 클래스에 공평하게 가중치 부여
remaining_weight = 1.0 - (weight_0 + weight_1)
weight_other = remaining_weight / (num_classes - 2)

# 가중치 리스트 생성
class_weights = [weight_0, weight_1] + [weight_other] * (num_classes - 2)

class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)
                                                                           
def criterion(y_pred, y):
    y = y.long().squeeze(1)
    ce = F.cross_entropy(y_pred, y, weight=class_weights_tensor)
    # ce = F.cross_entropy(y_pred, y)
    return ce

optimizer = optim.Adam(model.parameters(), lr=0.001)

Using cuda device


In [6]:

tensor1 = torch.randn(1, 10).to(device)
tensor2 = torch.randn(1, 10).to(device)

concatenated_tensor = torch.cat((tensor1, tensor2), dim=1)
print(concatenated_tensor.shape)  # Should print torch.Size([1, 1024])

torch.Size([1, 20])


In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResBlock(nn.Module):
    def __init__(self, C: int, dropout_prob: float):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.bnorm1 = nn.BatchNorm2d(C)
        self.bnorm2 = nn.BatchNorm2d(C)
        self.conv1 = nn.Conv2d(C, C, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(C, C, kernel_size=3, padding=1)
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, x):
        r = self.conv1(self.relu(self.bnorm1(x)))
        r = self.dropout(r)
        r = self.conv2(self.relu(self.bnorm2(r)))
        return r + x

class ConvBlock(nn.Module):
    def __init__(self, mode: str, C_in: int, C_out: int, dropout_prob: float):
        super().__init__()
        self.relu = nn.ReLU()
        self.bnorm = nn.BatchNorm2d(C_out)
        if mode == "down":
            self.conv = nn.Conv2d(C_in, C_out, kernel_size=4, stride=2, padding=0)
        elif mode == "up":
            self.conv = nn.ConvTranspose2d(C_in, C_out, kernel_size=4, stride=2, padding=0)
        elif mode == "same":
            self.conv = nn.Conv2d(C_in, C_out, kernel_size=3, padding=1)
        else:
            raise ValueError("Wrong ConvBlock mode.")
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, z):
        x = self.conv(z)
        x = self.bnorm(x)
        x = self.relu(x)
        x = self.dropout(x)
        return x

class Encoder(nn.Module):
    def __init__(self, channels=[256, 512, 512], latent_dim=512, dropout=0.1):
        super(Encoder, self).__init__()
        self.conv1 = ConvBlock("same", 11,  channels[0], dropout)
        self.res12 = ResBlock(channels[0], dropout)
        self.conv2 = ConvBlock("same", channels[0], channels[1], dropout)
        self.res23 = ResBlock(channels[1], dropout)
        self.conv3 = ConvBlock("same", channels[1], channels[2], dropout)
    
    def one_hot_encode(self, z, num_classes=11):
        """
        One-hot encode the input tensor and adjust the shape.
        """
        return F.one_hot(z.squeeze(1).long(), num_classes=num_classes).permute(0, 3, 1, 2).float()

    def forward(self, z):
        # One-hot encode the input inside the encoder
        print('z', z.shape)
        z_one_hot = self.one_hot_encode(z)
        print('z_one_hot', z_one_hot.shape)
        residuals = [0] * 3
        x = self.conv1(z_one_hot)
        x = self.res12(x)
        residuals[0] = x
        x = self.conv2(x)
        x = self.res23(x)
        residuals[1] = x
        x = self.conv3(x)
        residuals[2] = x
        return x, residuals

# 모델 인스턴스 생성
encoder = Encoder(channels=[256, 512, 512], latent_dim=512, dropout=0.1)

# 예시 입력 데이터 생성
batch_size = 32
input_height = 30
input_width = 30
# 입력 데이터 (batch_size, 1, height, width)
input_data = torch.randint(0, 11, (batch_size, 1, input_height, input_width))

# 모델에 예시 입력 데이터 넣어서 동작 확인
encoded_output, residuals = encoder(input_data)

# 결과 출력
print(f"Encoded output shape: {encoded_output.shape}")
print(f"Residual 1 shape: {residuals[0].shape}")
print(f"Residual 2 shape: {residuals[1].shape}")
print(f"Residual 3 shape: {residuals[2].shape}")


z torch.Size([32, 1, 30, 30])
z_one_hot torch.Size([32, 11, 30, 30])
Encoded output shape: torch.Size([32, 512, 30, 30])
Residual 1 shape: torch.Size([32, 256, 30, 30])
Residual 2 shape: torch.Size([32, 512, 30, 30])
Residual 3 shape: torch.Size([32, 512, 30, 30])


In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResBlock(nn.Module):
    def __init__(self, C: int, dropout_prob: float):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.bnorm1 = nn.BatchNorm2d(C)
        self.bnorm2 = nn.BatchNorm2d(C)
        self.conv1 = nn.Conv2d(C, C, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(C, C, kernel_size=3, padding=1)
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, x):
        r = self.conv1(self.relu(self.bnorm1(x)))
        r = self.dropout(r)
        r = self.conv2(self.relu(self.bnorm2(r)))
        return r + x

class ConvBlock(nn.Module):
    def __init__(self, mode: str, C_in: int, C_out: int, dropout_prob: float):
        super().__init__()
        self.relu = nn.ReLU()
        self.bnorm = nn.BatchNorm2d(C_out)
        if mode == "down":
            self.conv = nn.Conv2d(C_in, C_out, kernel_size=4, stride=2, padding=0)
        elif mode == "up":
            self.conv = nn.ConvTranspose2d(C_in, C_out, kernel_size=4, stride=2, padding=0)
        elif mode == "same":
            self.conv = nn.Conv2d(C_in, C_out, kernel_size=3, padding=1)
        else:
            raise ValueError("Wrong ConvBlock mode.")
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, z):
        x = self.conv(z)
        x = self.bnorm(x)
        x = self.relu(x)
        x = self.dropout(x)
        return x

class Decoder(nn.Module):
    def __init__(self, channels=[256, 512, 512], dropout=0.1):
        super(Decoder, self).__init__()
        # 'channels' 리스트는 Decoder의 각 단계에서 채널 수를 정의합니다.
        # ConvBlock의 입력 채널 수를 정확히 맞춰야 합니다.
        self.conv3 = ConvBlock("same", channels[-1] * 2, channels[-2], dropout)  # 512 * 2 -> 512
        self.res32 = ResBlock(channels[-2], dropout)  # 512
        self.conv2 = ConvBlock("same", channels[-2] * 2, channels[-3], dropout)  # 512 * 2 -> 256
        self.res21 = ResBlock(channels[-3], dropout)  # 256
        self.conv1 = ConvBlock("same", channels[-3] * 2, channels[-3], dropout)  # 256 * 2 -> 256
        self.conv0 = nn.Conv2d(channels[-3], 11, kernel_size=3, padding=1)  # Final output with 11 channels

    def forward(self, x, residuals):
        # Concatenate and process with conv3
        x = torch.cat([x, residuals[2]], dim=1)  # [32, 512*2, 30, 30]
        print(x.shape)  # Debug print
        x = self.conv3(x)  # [32, 512, 30, 30]
        print(x.shape)  # Debug print
        x = self.res32(x)  # [32, 512, 30, 30]
        print(x.shape)  # Debug print
        
        # Concatenate and process with conv2
        x = torch.cat([x, residuals[1]], dim=1)  # [32, 512*2, 30, 30]
        print(x.shape)  # Debug print
        x = self.conv2(x)  # [32, 256, 30, 30]
        print(x.shape)  # Debug print
        x = self.res21(x)  # [32, 256, 30, 30]
        print(x.shape)  # Debug print
        
        # Concatenate and process with conv1
        x = torch.cat([x, residuals[0]], dim=1)  # [32, 256*2, 30, 30]
        print(x.shape)  # Debug print
        x = self.conv1(x)  # [32, 256, 30, 30]
        print(x.shape)  # Debug print

        # Final output layer
        x = self.conv0(x)  # [32, 11, 30, 30]
        return x

# 디코더 인스턴스 생성
decoder = Decoder(channels=[256, 512, 512], dropout=0.1)

# 인코더에서 얻은 출력 및 Residual
encoded_output = torch.randn(32, 512, 30, 30)  # torch.Size([32, 512, 30, 30])
residuals = [
    torch.randn(32, 256, 30, 30),  # Residual 1: torch.Size([32, 256, 30, 30])
    torch.randn(32, 512, 30, 30),  # Residual 2: torch.Size([32, 512, 30, 30])
    torch.randn(32, 512, 30, 30)   # Residual 3: torch.Size([32, 512, 30, 30])
]

# 디코더에 인코더 출력과 residuals 넣어서 결과 확인
decoder_output = decoder(encoded_output, residuals)

# 결과 출력
print(f"Decoder output shape: {decoder_output.shape}")


torch.Size([32, 1024, 30, 30])
torch.Size([32, 512, 30, 30])
torch.Size([32, 512, 30, 30])
torch.Size([32, 1024, 30, 30])
torch.Size([32, 256, 30, 30])
torch.Size([32, 256, 30, 30])
torch.Size([32, 512, 30, 30])
torch.Size([32, 256, 30, 30])
Decoder output shape: torch.Size([32, 11, 30, 30])


In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# CBAM 모듈 정의
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        
        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

class CBAM(nn.Module):
    def __init__(self, channels, ratio=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.ca = ChannelAttention(channels, ratio)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        x = x * self.ca(x)
        x = x * self.sa(x)
        return x

# 테스트 코드
def test_cbam():
    # CBAM 인스턴스 생성 (예를 들어 채널 수가 512인 경우)
    cbam = CBAM(channels=512, ratio=16, kernel_size=7)

    # 가상의 인코더 출력 생성 (배치 크기: 32, 채널: 512, 크기: 30x30)
    encoder_output = torch.randn(32, 512, 30, 30)
    
    # CBAM 모듈을 통해 인코더 출력을 통과시킵니다.
    cbam_output = cbam(encoder_output)

    # 결과 출력
    print(f"CBAM output shape: {cbam_output.shape}")

test_cbam()

CBAM output shape: torch.Size([32, 512, 30, 30])


In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# CBAM 모듈 정의
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        
        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

class CBAM(nn.Module):
    def __init__(self, channels, ratio=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.ca = ChannelAttention(channels, ratio)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        x = x * self.ca(x)
        x = x * self.sa(x)
        return x

class ResBlock(nn.Module):
    def __init__(self, C: int, dropout_prob: float):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.bnorm1 = nn.BatchNorm2d(C)
        self.bnorm2 = nn.BatchNorm2d(C)
        self.conv1 = nn.Conv2d(C, C, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(C, C, kernel_size=3, padding=1)
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, x):
        r = self.conv1(self.relu(self.bnorm1(x)))
        r = self.dropout(r)
        r = self.conv2(self.relu(self.bnorm2(r)))
        return r + x

class ConvBlock(nn.Module):
    def __init__(self, mode: str, C_in: int, C_out: int, dropout_prob: float):
        super().__init__()
        self.relu = nn.ReLU()
        self.bnorm = nn.BatchNorm2d(C_out)
        if mode == "down":
            self.conv = nn.Conv2d(C_in, C_out, kernel_size=4, stride=2, padding=0)
        elif mode == "up":
            self.conv = nn.ConvTranspose2d(C_in, C_out, kernel_size=4, stride=2, padding=0)
        elif mode == "same":
            self.conv = nn.Conv2d(C_in, C_out, kernel_size=3, padding=1)
        else:
            raise ValueError("Wrong ConvBlock mode.")
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, z):
        x = self.conv(z)
        x = self.bnorm(x)
        x = self.relu(x)
        x = self.dropout(x)
        return x

class Encoder(nn.Module):
    def __init__(self, channels=[256, 512, 512], dropout=0.1):
        super(Encoder, self).__init__()
        self.conv1 = ConvBlock("same", 11,  channels[0], dropout)
        self.res12 = ResBlock(channels[0], dropout)
        self.conv2 = ConvBlock("same", channels[0], channels[1], dropout)
        self.res23 = ResBlock(channels[1], dropout)
        self.conv3 = ConvBlock("same", channels[1], channels[2], dropout)
    
    def one_hot_encode(self, z, num_classes=11):
        """
        One-hot encode the input tensor and adjust the shape.
        """
        return F.one_hot(z.squeeze(1).long(), num_classes=num_classes).permute(0, 3, 1, 2).float()

    def forward(self, z):
        # One-hot encode the input inside the encoder
        print('z', z.shape)
        z_one_hot = self.one_hot_encode(z)
        print('z_one_hot', z_one_hot.shape)
        residuals = [0] * 3
        x = self.conv1(z_one_hot)
        x = self.res12(x)
        residuals[0] = x
        x = self.conv2(x)
        x = self.res23(x)
        residuals[1] = x
        x = self.conv3(x)
        residuals[2] = x
        return x, residuals
    
class Decoder(nn.Module):
    def __init__(self, channels=[256, 512, 512], dropout=0.1):
        super(Decoder, self).__init__()
        # 'channels' 리스트는 Decoder의 각 단계에서 채널 수를 정의합니다.
        # ConvBlock의 입력 채널 수를 정확히 맞춰야 합니다.
        self.conv3 = ConvBlock("same", channels[-1] * 2, channels[-2], dropout)  # 512 * 2 -> 512
        self.res32 = ResBlock(channels[-2], dropout)  # 512
        self.conv2 = ConvBlock("same", channels[-2] * 2, channels[-3], dropout)  # 512 * 2 -> 256
        self.res21 = ResBlock(channels[-3], dropout)  # 256
        self.conv1 = ConvBlock("same", channels[-3] * 2, channels[-3], dropout)  # 256 * 2 -> 256
        self.conv0 = nn.Conv2d(channels[-3], 11, kernel_size=3, padding=1)  # Final output with 11 channels

    def forward(self, x, residuals):
        # Concatenate and process with conv3
        x = torch.cat([x, residuals[2]], dim=1)  # [32, 512*2, 30, 30]
        print(x.shape)  # Debug print
        x = self.conv3(x)  # [32, 512, 30, 30]
        print(x.shape)  # Debug print
        x = self.res32(x)  # [32, 512, 30, 30]
        print(x.shape)  # Debug print
        
        # Concatenate and process with conv2
        x = torch.cat([x, residuals[1]], dim=1)  # [32, 512*2, 30, 30]
        print(x.shape)  # Debug print
        x = self.conv2(x)  # [32, 256, 30, 30]
        print(x.shape)  # Debug print
        x = self.res21(x)  # [32, 256, 30, 30]
        print(x.shape)  # Debug print
        
        # Concatenate and process with conv1
        x = torch.cat([x, residuals[0]], dim=1)  # [32, 256*2, 30, 30]
        print(x.shape)  # Debug print
        x = self.conv1(x)  # [32, 256, 30, 30]
        print(x.shape)  # Debug print

        # Final output layer
        x = self.conv0(x)  # [32, 11, 30, 30]
        return x

class CBAMAEmodel(nn.Module):
    def __init__(self, channels=[256, 512, 512], embed_dim=512, dropout=0.1):
        super(CBAMAEmodel, self).__init__()
        self.preprocess_and_embed = Encoder(channels, dropout)
        self.decoder = Decoder(channels, dropout)
        self.cbam1 = CBAM(channels=embed_dim*2)  # edu_input과 edu_output 결합 후 적용
        self.cbam2 = CBAM(channels=embed_dim*2)  
        self.reduce_channels = nn.Conv2d(embed_dim*2, embed_dim*1, kernel_size=1)  # 채널 축소
        self.reduce_channels2 = nn.Conv2d(embed_dim*2, embed_dim*1, kernel_size=1)  # 채널 축소

    def forward(self, edu_input, edu_output, task_input):
        # edu_input, edu_output, task_input 임베딩
        embedded_edu_input, _ = self.preprocess_and_embed(edu_input)
        embedded_edu_output, _ = self.preprocess_and_embed(edu_output)
        embedded_task_input, residuals = self.preprocess_and_embed(task_input)

        # edu_input과 edu_output 결합
        combined_edu = torch.cat([embedded_edu_input, embedded_edu_output], dim=1)
        
        # combined_edu의 크기 출력
        print(f"combined_edu shape: {combined_edu.shape}")

        # CBAM 적용
        attended_edu = self.cbam1(combined_edu)
        print(f"attended_edu shape: {attended_edu.shape}")
        attended_edu = self.reduce_channels(attended_edu)
        print(f"attended_edu shape after reduction: {attended_edu.shape}")
        # attended_edu와 task_input 결합
        task_edu_cat = torch.cat([attended_edu, embedded_task_input], dim=1)
        print(f"task_edu_cat shape: {task_edu_cat.shape}")
        # task_causal_mix 처리 후 디코더 적용
        task_causal_mix = self.cbam2(task_edu_cat)
        print(f"task_causal_mix shape: {task_causal_mix.shape}")
        task_causal_mix = self.reduce_channels2(task_causal_mix)
        print(f"task_causal_mix shape after reduction: {task_causal_mix.shape}")
        final_output = self.decoder(task_causal_mix, residuals)

        return final_output
    
model = CBAMAEmodel(channels=[256, 512, 512], embed_dim=512)
batch_size = 32
input_height = 30
input_width = 30
task_input = torch.randint(0, 11, (batch_size, 1, input_height, input_width))
edu_input = torch.randint(0, 11, (batch_size, 1, input_height, input_width))
edu_output = torch.randint(0, 11, (batch_size, 1, input_height, input_width))

result = model(edu_input, edu_output, task_input)
print("Result shape:", result.shape)

z torch.Size([32, 1, 30, 30])
z_one_hot torch.Size([32, 11, 30, 30])
z torch.Size([32, 1, 30, 30])
z_one_hot torch.Size([32, 11, 30, 30])
z torch.Size([32, 1, 30, 30])
z_one_hot torch.Size([32, 11, 30, 30])
combined_edu shape: torch.Size([32, 1024, 30, 30])
attended_edu shape: torch.Size([32, 1024, 30, 30])
attended_edu shape after reduction: torch.Size([32, 512, 30, 30])
task_edu_cat shape: torch.Size([32, 1024, 30, 30])
task_causal_mix shape: torch.Size([32, 1024, 30, 30])
task_causal_mix shape after reduction: torch.Size([32, 512, 30, 30])
torch.Size([32, 1024, 30, 30])
torch.Size([32, 512, 30, 30])
torch.Size([32, 512, 30, 30])
torch.Size([32, 1024, 30, 30])
torch.Size([32, 256, 30, 30])
torch.Size([32, 256, 30, 30])
torch.Size([32, 512, 30, 30])
torch.Size([32, 256, 30, 30])
Result shape: torch.Size([32, 11, 30, 30])


In [23]:
channels    = [256, 512, 512]
latent_dim  = 512
hidden_size = 512

def count_parameters(model):
    """Count the trainable parameters in a model."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

model_size = count_parameters(model)
print(f"model {int(model_size/1e6)}M")


model 24M


Result shape: torch.Size([32, 11, 30, 30])
