In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

In [2]:
# One seed shared by the PyTorch CPU RNG, all CUDA RNGs and NumPy's legacy global RNG.
SEED = 172
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)

# Data split / optimisation settings.
TRAIN_RATIO = 0.8  # fraction of the dataset handed to the training split
BATCH_SIZE = 1024  # batch size used by both DataLoaders
LEARNING_RATE = 0.001  # initial learning rate passed to Adam
EPOCH = 180  # not read by any cell visible here; presumably the training epoch budget — TODO confirm

# Model dimensions, passed positionally to Transformer(...).
INPUT_SIZE = 30  # length of one flattened input vector; presumably 10 points x 3 coordinates — TODO confirm
TRANSFORMER_INPUT_SIZE = 128  # d_model of the encoder
NUM_HEADS = 8  # attention heads per encoder layer
NUM_LAYERS = 6  # number of stacked encoder layers
OUTPUT_SIZE = 6  # width of the target built by Dataset when the raw target row has five values

In [3]:
class Dataset(Dataset):
    """Track samples parsed from a text file of ``EOT``-separated records.

    Each record is a group of lines holding comma-separated floats. The first
    line of a record is the target row; every following line is one input row.

    Attributes:
        original_targets: ``np.ndarray`` (float64) with the raw target rows.
        targets_cos_sin: float32 tensor holding the target row with column 1
            removed and ``cos(column 1)``, ``sin(column 1)`` appended at the end.
        inputs: float32 tensor; row ``i`` is all input rows of record ``i``
            concatenated into a single flat vector.
        transform: the ``transform`` argument, stored as given. It is not
            applied anywhere in this class.

    Tensors are stored as float32 because ``nn.Linear`` parameters default to
    float32 and reject float64 inputs.

    Args:
        path: path of the text file to parse.
        transform: accepted for interface compatibility; currently unused.

    Raises:
        ValueError: if the file contains no records, or a cell is not a float.
    """

    def __init__(self, path, transform=None):
        self.transform = transform

        # Only the read needs the file handle; parse after it is closed.
        with open(path, 'r') as file:
            content = file.read()

        records = []
        for chunk in content.split('EOT'):
            chunk = chunk.strip()
            if not chunk:
                continue
            # float() ignores surrounding whitespace, so splitting on ',' accepts
            # both "a, b" and "a,b"; blank lines inside a record are skipped.
            rows = [
                [float(cell) for cell in line.split(',')]
                for line in chunk.split('\n')
                if line.strip()
            ]
            records.append(rows)

        if not records:
            raise ValueError(f'No data points found in {path!r}')

        self.original_targets = np.array([rows[0] for rows in records])

        # Replace the angle in column 1 by its cosine and sine, appended last.
        angles = self.original_targets[:, 1]
        without_angle = np.delete(self.original_targets, 1, 1)
        encoded = np.column_stack((without_angle, np.cos(angles), np.sin(angles)))
        self.targets_cos_sin = torch.tensor(encoded, dtype=torch.float32)

        # Concatenate every input row of a record into one flat vector.
        flattened = [
            [value for row in rows[1:] for value in row]
            for rows in records
        ]
        self.inputs = torch.tensor(np.array(flattened), dtype=torch.float32)
        # self.inputs = torch.tensor(np.array(input_points))

    def __len__(self):
        return len(self.original_targets)

    def __getitem__(self, idx):
        # Returned as (input, target), the order the DataLoader loops unpack.
        return self.inputs[idx], self.targets_cos_sin[idx]
    
class Transformer(nn.Module):
    """Linear projection -> stacked Transformer encoder -> linear output head.

    Args:
        input_size: number of features in the last dimension of the input.
        model_input_size: encoder width (``d_model``).
        num_of_heads: attention heads per encoder layer.
        num_of_layers: number of encoder layers in the stack.
        output_size: number of features produced by the output head.

    The template ``encoder_layer`` stays an attribute of the module, so it is a
    registered submodule of its own. Attribute names and their creation order
    are deliberately left as they were; renaming or removing one would change
    the module's parameter names.
    """

    def __init__(self, input_size, model_input_size, num_of_heads, num_of_layers, output_size):
        super().__init__()
        self.input_projection = nn.Linear(in_features=input_size, out_features=model_input_size)
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=model_input_size,
            nhead=num_of_heads,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_of_layers)
        self.output_layer = nn.Linear(in_features=model_input_size, out_features=output_size)

    def forward(self, x):
        """Map ``x`` of shape ``(..., input_size)`` to ``(..., output_size)``.

        No pooling is applied (the mean over dim 1 is disabled), so every
        leading dimension of ``x`` is preserved in the result.
        NOTE(review): a 2-D ``(batch, input_size)`` input reaches the encoder as
        a 2-D tensor, which it presumably treats as one unbatched sequence —
        confirm that attention across samples is intended.
        """
        projected = self.input_projection(x)
        encoded = self.encoder(projected)
        # encoded = encoded.mean(dim=1)
        return self.output_layer(encoded)
    
class MSEWithTrigConstraint(nn.Module):
    """MSE loss plus a penalty keeping (prediction[:, 4], prediction[:, 5]) on the unit circle.

    The loss is ``MSE(prediction, target) + trig_lagrangian * mean(|1 - norm|)``
    where ``norm = sqrt(prediction[:, 4]**2 + prediction[:, 5]**2)``.

    Args:
        trig_lagrangian: weight of the unit-norm penalty.
        device: kept only so existing call sites keep working. The loss is
            computed on whatever device ``prediction`` lives on, so this value
            is stored but no longer used to move tensors.
    """

    def __init__(self, trig_lagrangian=1000, device='cuda'):
        super().__init__()
        self.trig_lagrangian = trig_lagrangian
        self.mse_loss = nn.MSELoss()
        self.device = device

    def forward(self, prediction, target):
        """Return the scalar loss for ``prediction`` against ``target``.

        Args:
            prediction: tensor indexed as ``prediction[:, 4]`` (cosine) and
                ``prediction[:, 5]`` (sine), so it needs at least six columns.
            target: tensor accepted by ``nn.MSELoss`` together with ``prediction``.
        """
        mse = self.mse_loss(prediction, target)
        cos_values = prediction[:, 4]
        sin_values = prediction[:, 5]
        # Stay on prediction's device: forcing a fixed device crashed on
        # CPU-only hosts and could mix devices inside one loss value.
        norm = torch.sqrt(torch.square(cos_values) + torch.square(sin_values))
        constraint_loss = torch.abs(1.0 - norm)
        # Out-of-place sum instead of `+=` on the MSE tensor.
        return mse + self.trig_lagrangian * torch.mean(constraint_loss)

In [4]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Parse every sample, then split into train/validation subsets. random_split
# draws from the global torch RNG seeded in the configuration cell.
dataset = Dataset(path='tracks.txt')
train_size = int(TRAIN_RATIO * len(dataset))
val_size = len(dataset) - train_size  # remainder, so the two sizes always sum to len(dataset)
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=1)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=1)

model = Transformer(INPUT_SIZE, TRANSFORMER_INPUT_SIZE, NUM_HEADS, NUM_LAYERS, OUTPUT_SIZE)
model = model.to(device)
# The former `model.cuda()` call repeated the `.to(device)` above, and the
# `nn.DataParallel` wrapper was bound to a name nothing else read, so both are
# dropped. `model` stays the plain module, which keeps its parameter names
# free of the `module.` prefix DataParallel would add.

criterion = MSEWithTrigConstraint(trig_lagrangian=100000, device=device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Halve the learning rate at each listed epoch.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 10, 25, 40], gamma=0.5)

Using device: cuda


In [5]:
# Checkpoint file restored by the next cell; the path is relative to the working directory.
prev_model_path = 'transformer_trig_constraint/transformer_trig_epoch_180.pth'

In [6]:
print(f"Loading model from {prev_model_path}")
# map_location remaps stored tensors onto the active device, so a checkpoint
# written on a GPU machine can still be restored on a CPU-only one.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load(prev_model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler'])
# First epoch index a resumed run would start from; not read by the cells visible here.
start_epoch = checkpoint['epoch'] + 1

Loading model from transformer_trig_constraint/transformer_trig_epoch_180.pth


In [7]:
# iter_val_dataloader = iter(val_dataloader)
import plotly.graph_objects as go

# Line width read by plot_helix_interactive at call time.
# NOTE(review): the matplotlib cell further down rebinds this same global to 0.5,
# so the width actually used depends on which cell ran last.
LINE_WIDTH = 2.0

def track(phi, d0, phi0, pt, dz, tanl):
    """Evaluate the helix coordinates at angle(s) ``phi``.

    Args:
        phi: scalar or NumPy array of angles along the helix.
        d0, phi0, pt, dz, tanl: helix parameters, used exactly as in the
            formulas below.

    Returns:
        Tuple ``(x, y, z)``; each entry has the shape of ``phi``.
    """
    field_factor = 1 / 2  # 1/cB, magnetic field and charge effect factor
    charge = 1  # charge of the particle
    curvature = charge / pt
    radius = field_factor / curvature

    cos_start = np.cos(phi0)
    sin_start = np.sin(phi0)
    swept_angle = phi0 + phi

    x = d0 * cos_start + radius * (cos_start - np.cos(swept_angle))
    y = d0 * sin_start + radius * (sin_start - np.sin(swept_angle))
    z = dz - radius * tanl * phi
    return x, y, z

def plot_helix_interactive(target, output, distinct_points=None):
    """Show target and predicted helices in an interactive plotly 3-D figure.

    Args:
        target: sequence unpacked into ``track`` after ``phi`` (drawn in red).
        output: sequence unpacked into ``track`` after ``phi`` (drawn in black).
        distinct_points: optional array indexed as ``[:, 0]``, ``[:, 1]``,
            ``[:, 2]`` and drawn as blue markers.
    """
    sweep = np.linspace(0, 4 * np.pi, 1000)

    def helix_trace(params, colour, label):
        # One line trace per helix; LINE_WIDTH is looked up when this runs.
        xs, ys, zs = track(sweep, *params)
        return go.Scatter3d(
            x=xs, y=ys, z=zs,
            mode='lines',
            line=dict(color=colour, width=LINE_WIDTH),
            name=label,
        )

    fig = go.Figure(data=[helix_trace(target, 'red', 'Target_Path')])
    fig.add_trace(helix_trace(output, 'black', 'Output_Path'))

    if distinct_points is not None:
        markers = go.Scatter3d(
            x=distinct_points[:, 0],
            y=distinct_points[:, 1],
            z=distinct_points[:, 2],
            mode='markers',
            marker=dict(size=3, color='blue'),
        )
        fig.add_trace(markers)

    axis_titles = dict(xaxis_title='X Axis', yaxis_title='Y Axis', zaxis_title='Z Axis')
    fig.update_layout(
        title='Interactive 3D Helix',
        scene=axis_titles,
        margin=dict(l=0, r=0, b=0, t=0),  # Tight layout
    )
    fig.show()

In [8]:
import matplotlib.pyplot as plt

# Line width read by plot_helix_static at call time.
# NOTE(review): this rebinds the LINE_WIDTH set to 2.0 in the plotly cell, so
# plot_helix_interactive also draws with 0.5 once this cell has run.
LINE_WIDTH = 0.5

def plot_helix_static(target, output, distinct_points=None, view_angle=None):
    """Draw target and predicted helices on a static matplotlib 3-D axes.

    Args:
        target: sequence unpacked into ``track`` after ``phi`` (drawn in red).
        output: sequence unpacked into ``track`` after ``phi`` (drawn in black).
        distinct_points: optional array indexed as ``[:, 0]``, ``[:, 1]``,
            ``[:, 2]`` and drawn as blue dots.
        view_angle: ``'top'`` or ``'side'`` selects a fixed camera; any other
            value leaves the default view.
    """
    sweep = np.linspace(0, 4 * np.pi, 1000)
    target_curve = track(sweep, *target)
    output_curve = track(sweep, *output)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # Target first, then prediction, so legend order stays the same.
    curves = (
        (target_curve, 'Target Path', 'red'),
        (output_curve, 'Output Path', 'black'),
    )
    for (xs, ys, zs), label, colour in curves:
        ax.plot(xs, ys, zs, label=label, color=colour, linewidth=LINE_WIDTH)

    if distinct_points is not None:
        px = distinct_points[:, 0]
        py = distinct_points[:, 1]
        pz = distinct_points[:, 2]
        ax.scatter(px, py, pz, color='blue', s=10)

    ax.legend()
    ax.set_xlabel('X Axis')
    ax.set_ylabel('Y Axis')
    ax.set_zlabel('Z Axis')
    plt.title('Helix Visualization')

    # Fixed camera presets: (name, elevation, azimuth).
    for preset, elevation, azimuth in (('top', 90, 90), ('side', 0, 0)):
        if view_angle == preset:
            ax.view_init(elev=elevation, azim=azimuth)
            break

    plt.show()

In [9]:
test_input = None
test_target = None
test_output = None
# NumPy, matplotlib and plotly need host memory, so the picked sample is moved to the CPU.
temp_device = 'cpu'

INDEX = 6  # position of the inspected sample inside the first batch


def decode_track_params(encoded):
    """Turn the 6-value cos/sin encoding back into the 5 arguments `track` takes.

    Dataset drops column 1 of the raw target row and appends its cosine and
    sine, so the angle is recovered with arctan2 and put back at index 1.
    The names below follow `track`'s signature — this assumes the raw target
    row is ordered (d0, phi0, pt, dz, tanl); TODO confirm against tracks.txt.
    """
    d0, pt, dz, tanl, cos_phi0, sin_phi0 = np.asarray(encoded, dtype=float)
    phi0 = np.arctan2(sin_phi0, cos_phi0)
    return np.array([d0, phi0, pt, dz, tanl])


# Inference mode: switches dropout off so the prediction is deterministic.
# (Call model.train() again before resuming any training.)
model.eval()
with torch.no_grad():
    # NOTE(review): the sample is drawn from the shuffled training loader.
    for inputs, targets in train_dataloader:
        # The model's parameters are float32, so inputs are cast to match.
        inputs = inputs.to(device=device, dtype=torch.float32)
        targets = targets.to(device=device, dtype=torch.float32)
        outputs = model(inputs)
        # The plot helpers index points as [:, 0..2]; this assumes the flat input
        # vector is consecutive (x, y, z) triples — TODO confirm.
        test_input = inputs[INDEX].to(temp_device).numpy().reshape(-1, 3)
        test_target = decode_track_params(targets[INDEX].to(temp_device).numpy())
        test_output = decode_track_params(outputs[INDEX].to(temp_device).numpy())
        break

print(test_input)
print(f'Target: {test_target}')
print(f'Output: {test_output}')
# plot_helix_interactive(test_target, test_output, test_input)

In [None]:
# Render the sample selected in the previous cell twice with the static
# plotter: first from the top, then from the side.
# test_output[1] -= 0.25
for view in ('top', 'side'):
    plot_helix_static(test_target, test_output, test_input, view_angle=view)