In [None]:
import os
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.onnx
import torch.utils
import torch.utils.data

from model import ResNet

device = torch.device("cuda:0")

columns = 5
rows = 5
channels = 64
layers = 15
epochs = 100
training_batch_size = 64
inference_batch_size = 256
kl_loss_scale = 0.1

data_folder = "../data"
models_folder = "../models"
generation = 1

class Snapshots(torch.utils.data.Dataset):
    """Dataset of snapshots read from a text file.

    The file holds records separated by one or more blank lines. Each record is
    exactly four non-blank lines of comma-separated floats, in this order:
    state, wall priors, step priors, values. The state line is reshaped to
    ``shape``; the other three lines are kept as 1-D tensors.

    Args:
        file_name: Path of the snapshot file.
        shape: Shape of the state tensor. Defaults to ``(7, columns, rows)``
            using the module-level board size.

    Raises:
        ValueError: If a record has more or fewer than four lines, or a field
            cannot be parsed as a float.
    """

    # Lines per record: state, wall priors, step priors, values.
    _FIELDS = 4

    def __init__(self, file_name, shape=None):
        if shape is None:
            shape = (7, columns, rows)
        self.data = [[] for _ in range(self._FIELDS)]
        field = 0
        with open(file_name) as f:
            for line_number, line in enumerate(f, start=1):
                if not line.strip():
                    # A blank line closes the current record; repeated blanks are harmless.
                    self._check_record_complete(field, file_name, line_number)
                    field = 0
                    continue
                if field == self._FIELDS:
                    raise ValueError(
                        f"{file_name}:{line_number}: record has more than {self._FIELDS} lines"
                    )
                # Split on bare commas; float() tolerates the surrounding whitespace.
                t = torch.tensor([float(x) for x in line.split(",")])
                if field == 0:
                    t = t.view(*shape)
                self.data[field].append(t)
                field += 1
        self._check_record_complete(field, file_name, "end of file")

    def _check_record_complete(self, field, file_name, where):
        """Raise ValueError if a record stopped after 1-3 lines (0 means no record is open)."""
        if field not in (0, self._FIELDS):
            raise ValueError(
                f"{file_name}:{where}: record ended after {field} of {self._FIELDS} lines"
            )

    def __len__(self):
        return len(self.data[0])

    def __getitem__(self, index):
        """Return ``[state, wall_priors, step_priors, values]`` for record ``index``."""
        return [column[index] for column in self.data]

In [None]:
def loss_fn(wp_out, sp_out, vs_out, wp_label, sp_label, vs_label, kl_scale=None):
    """Compute the policy and value losses for one batch, both summed over the batch.

    The wall and step logits are concatenated along dim 1 and normalised with a
    single log-softmax, so they form one combined action distribution. That
    distribution is compared to the concatenated wall/step labels with KL
    divergence. The value output is compared to its label with squared error.

    Args:
        wp_out, sp_out: Wall / step logits from the model.
        vs_out: Value output from the model.
        wp_label, sp_label: Target probabilities matching wp_out / sp_out.
        vs_label: Target values matching vs_out.
        kl_scale: Multiplier on the KL term. ``None`` uses the module-level
            ``kl_loss_scale``.

    Returns:
        ``(kl_loss, mse_loss)`` as scalar tensors (reduction='sum').
    """
    if kl_scale is None:
        kl_scale = kl_loss_scale

    log_probs = F.log_softmax(torch.cat([wp_out, sp_out], dim=1), dim=1)
    actions_label = torch.cat([wp_label, sp_label], dim=1)

    # F.kl_div expects log-probabilities as input and plain probabilities as target.
    kl_loss = kl_scale * F.kl_div(log_probs, actions_label, reduction="sum")
    mse_loss = F.mse_loss(vs_out, vs_label, reduction="sum")

    return (kl_loss, mse_loss)


def save_model(model, folder):
    """Save `model` for the current `generation` in two formats inside `folder`.

    Writes ``model_{generation}.pt`` (the whole module via ``torch.save``) and
    ``model_{generation}.onnx``. The ONNX trace has its input named "States"
    and its outputs named "WallPriors", "StepPriors" and "Values". `folder` is
    created if it does not exist yet.

    The trace uses a random input of shape
    ``(inference_batch_size, 7, columns, rows)`` on `device`. No
    ``dynamic_axes`` are passed, so per the torch.onnx.export docs the exported
    input shape is static.
    """
    # torch.save does not create missing directories.
    os.makedirs(folder, exist_ok=True)
    stem = os.path.join(folder, f"model_{generation}")

    torch.save(model, f"{stem}.pt")

    dummy_input = torch.randn(inference_batch_size, 7, columns, rows, device=device)
    torch.onnx.export(
        model,
        dummy_input,
        f"{stem}.onnx",
        input_names=["States"],
        output_names=["WallPriors", "StepPriors", "Values"],
    )

In [None]:
# Build the network (ResNet from model.py) from the configured columns, rows,
# channels and layers, then move its parameters to the target device.
model = ResNet(columns, rows, channels, layers)
model = model.to(device)

In [None]:
# Read the snapshot files of generations (generation - 1) // 2 .. generation - 1.
training_window = range((generation - 1) // 2, generation)
snapshots = torch.utils.data.ConcatDataset(
    [Snapshots(f"{data_folder}/snapshots_{i}.csv") for i in training_window]
)

# A dedicated, seeded generator makes the 80/20 split identical on every re-run.
split_seed = 0
training_data, eval_data = torch.utils.data.random_split(
    snapshots, [0.8, 0.2], generator=torch.Generator().manual_seed(split_seed)
)

# Shared loader settings. persistent_workers keeps the 4 worker processes alive
# between epochs instead of re-spawning them for every pass.
loader_options = dict(
    batch_size=training_batch_size,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True,
)
training_loader = torch.utils.data.DataLoader(training_data, shuffle=True, **loader_options)
eval_loader = torch.utils.data.DataLoader(eval_data, shuffle=False, **loader_options)

# NOTE(review): lr=0.1 is 100x PyTorch's AdamW default (1e-3); confirm this is intended.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.1)

In [None]:
import matplotlib.pyplot as plt

def plot_tensor(mat, cmap='viridis'):
    """
    Plot a 2D PyTorch tensor as a grid of colored squares, one per element.

    Parameters:
    - mat: 2D PyTorch tensor of floats. It may live on any device or require
      grad; it is detached and copied to the CPU before plotting.
    - cmap: Colormap for visualizing the values in the tensor.

    Raises:
    - ValueError: if `mat` is not 2D.
    """
    if mat.dim() != 2:
        raise ValueError("Input tensor must be 2D")

    # .numpy() only works on CPU tensors that do not require grad.
    matrix = mat.detach().cpu().numpy()

    fig, ax = plt.subplots(figsize=(6, 6))
    image = ax.imshow(matrix, cmap=cmap, interpolation='nearest')
    fig.colorbar(image, ax=ax)
    # imshow draws array rows along the y axis and columns along the x axis.
    ax.set_xticks(range(matrix.shape[1]))
    ax.set_yticks(range(matrix.shape[0]))
    ax.grid(False)
    plt.show()

In [None]:
# Visualize the first input plane of the fourth record's state tensor.
plot_tensor(mat=snapshots[3][0][0], cmap='viridis')

In [None]:
# Total number of records across all loaded snapshot files.
len(snapshots)

In [None]:


def run_epoch(loader, train):
    """Make one pass over `loader` and return the mean (KL, MSE) loss per sample.

    Args:
        loader: DataLoader yielding (states, wall_priors, step_priors, values) batches.
        train: If true, the model is put in training mode and the optimizer
            steps after every batch. If false, the model is put in evaluation
            mode and no autograd graph is built.

    Returns:
        (mean_kl_loss, mean_mse_loss) as Python floats.
    """
    # Set the mode explicitly on every pass; otherwise eval mode from the
    # previous epoch's evaluation would leak into the next training pass.
    model.train(train)
    total_kl_loss = 0.0
    total_mse_loss = 0.0
    with torch.set_grad_enabled(train):
        for states, wall_priors, step_priors, values in loader:
            states = states.to(device, non_blocking=True)
            wall_priors = wall_priors.to(device, non_blocking=True)
            step_priors = step_priors.to(device, non_blocking=True)
            values = values.to(device, non_blocking=True)

            if train:
                optimizer.zero_grad()
            # Call the module rather than .forward() so registered hooks run.
            wp, sp, vs = model(states)
            kl_loss, mse_loss = loss_fn(wp, sp, vs, wall_priors, step_priors, values)
            if train:
                (kl_loss + mse_loss).backward()
                optimizer.step()

            total_kl_loss += kl_loss.item()
            total_mse_loss += mse_loss.item()

    # loss_fn sums over the batch, so normalise by the number of samples, not batches.
    sample_count = max(len(loader.dataset), 1)
    return total_kl_loss / sample_count, total_mse_loss / sample_count


def log_epoch_loss(phase, epoch, kl_loss, mse_loss):
    """Print one epoch's mean losses as `kl + mse = total`."""
    print(
        f"{phase} loss in epoch {epoch} of generation {generation}: {kl_loss} + {mse_loss} = {kl_loss + mse_loss}."
    )


try:
    # Runs for the configured number of epochs; interrupt the kernel to stop early.
    for epoch in range(epochs):
        log_epoch_loss("Training", epoch, *run_epoch(training_loader, train=True))
        log_epoch_loss("Evaluation", epoch, *run_epoch(eval_loader, train=False))
except KeyboardInterrupt:
    print("Training was interrupted.")
finally:
    # Leave the model in evaluation mode for the inference and export cells that follow.
    model.train(False)

In [None]:
# Take the first record's state and add a leading batch dimension of size 1.
initial_state = snapshots[0][0].unsqueeze(0).to(device)

In [None]:
# Run the first state through the model in evaluation mode without building an
# autograd graph. Calling the module (rather than .forward) also runs any hooks.
model.train(False)
with torch.no_grad():
    initial_output = model(initial_state)
initial_output

In [None]:
# Inspect the first record: [state, wall priors, step priors, values].
snapshots[0]

In [None]:
# Persist the trained network for this generation (.pt and .onnx files).
save_model(model, folder=models_folder)