# In [5]:
from bw_net import BWNet
import torch
from torch.utils.data import DataLoader
from dataloader import ARC_Dataset
from torch import optim
import torch.nn.functional as F

def criterion(y_pred, y):
    """Return the cross-entropy loss of logits ``y_pred`` against labels ``y``.

    ``F.cross_entropy`` expects integer class indices as targets, so ``y`` is
    cast to ``int64`` first in case the labels arrive with another dtype.
    The reduction is ``F.cross_entropy``'s default (mean).
    """
    return F.cross_entropy(y_pred, y.long())

# ARC Prize 2024 data files (Kaggle input layout).
train_challenge = './kaggle/input/arc-prize-2024/arc-agi_training_challenges.json'
train_solution = "./kaggle/input/arc-prize-2024/arc-agi_training_solutions.json"
eval_challenge = "./kaggle/input/arc-prize-2024/arc-agi_evaluation_challenges.json"
eval_solution = "./kaggle/input/arc-prize-2024/arc-agi_evaluation_solutions.json"


kwargs = {
    'epochs': 5,
    # Number of tasks per DataLoader batch (used as batch_size below).
    'task_numbers': 10,
    # NOTE(review): the next three keys are not read in this cell; presumably
    # consumed by meta-learning code elsewhere — TODO confirm or remove.
    'task_data_num': 1,
    'example_data_num': 20,  # intended to match the inner model batch size
    'inner_lr': 0.01,
    # Learning rate of the AdamW optimizer over model.parameters().
    'outer_lr': 0.001,
}
train_dataset = ARC_Dataset(train_challenge, train_solution)
train_loader = DataLoader(train_dataset, batch_size=kwargs['task_numbers'], shuffle=True)

# Score on the held-out evaluation split so the reported accuracy reflects
# generalisation to unseen tasks rather than fit to the training tasks.
eval_dataset = ARC_Dataset(eval_challenge, eval_solution)
eval_loader = DataLoader(eval_dataset, batch_size=kwargs['task_numbers'], shuffle=False)

model = BWNet()
optimizer = optim.AdamW(model.parameters(), lr=kwargs['outer_lr'])

def _train_one_epoch(model, loader, optimizer):
    """Train ``model`` for one pass over ``loader``, stepping once per task.

    Each batch from ``loader`` unpacks into (input, output, example_input,
    example_output). Every task in the batch gets its own forward pass,
    backward pass and optimizer step.

    Returns:
        Mean of the per-task losses over the epoch, or ``float('nan')`` if
        the loader yielded no tasks.
    """
    model.train()
    loss_sum = 0.0
    steps = 0
    for input_tensor, output_tensor, example_input, example_output in loader:
        # Use the real batch length: the final DataLoader batch can be smaller
        # than batch_size (drop_last defaults to False), so a fixed
        # range(kwargs['task_numbers']) could index past the end.
        for task_number in range(len(input_tensor)):
            optimizer.zero_grad()
            prediction = model(
                input_tensor[task_number],
                example_input[task_number],
                example_output[task_number],
            )
            # NOTE(review): squeeze(1) presumably drops a singleton dim so the
            # target matches prediction minus its class dim — TODO confirm
            # against ARC_Dataset's output shapes.
            task_output = output_tensor[task_number].squeeze(1)
            loss = criterion(prediction, task_output)
            loss.backward()
            optimizer.step()
            loss_sum += loss.item()
            steps += 1
    return loss_sum / steps if steps else float('nan')


def _evaluate(model, loader):
    """Return the element-wise accuracy (percent) of ``model`` on ``loader``.

    Predictions are compared with each task's target output, using the same
    per-task model call and target preprocessing as training. Returns
    ``float('nan')`` when there is nothing to score.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for input_tensor, output_tensor, example_input, example_output in loader:
            for task_number in range(len(input_tensor)):
                prediction = model(
                    input_tensor[task_number],
                    example_input[task_number],
                    example_output[task_number],
                )
                target = output_tensor[task_number].squeeze(1).long()
                # criterion (cross_entropy) treats dim 1 as the class
                # dimension, so argmax over dim 1 yields predicted labels.
                predicted = prediction.argmax(dim=1)
                correct += (predicted == target).sum().item()
                # Count the same elements that were compared so the ratio
                # is a true fraction (never above 100%).
                total += target.numel()
    return 100 * correct / total if total else float('nan')


for epoch in range(kwargs['epochs']):
    epoch_loss = _train_one_epoch(model, train_loader, optimizer)
    accuracy = _evaluate(model, eval_loader)
    print(f'Epoch {epoch+1}/{kwargs["epochs"]}, Loss: {epoch_loss}, Accuracy: {accuracy}%')