In [None]:
import torch
import torch.nn as nn
import torch.utils.data as data
import numpy as np
import pandas as pd
import os
import math
from sklearn.model_selection import train_test_split
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional-encoding table.

    A ``(lmax, d_model)`` table is built once at construction time and stored
    as the non-trainable buffer ``pe``. Even columns hold sines and odd
    columns hold cosines of ``position * base ** (-2i / d_model)``.

    Args:
        d_model: Width of each encoding vector (even or odd).
        lmax: Number of positions (rows) to precompute.
        base: Base of the geometric frequency progression.
    """

    def __init__(self, d_model, lmax, base=1000):
        super().__init__()
        self.d_model = d_model
        self.lmax = lmax
        self.base = base

        position = torch.arange(0, lmax).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(base) / d_model))
        pe = torch.zeros(lmax, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine column than sine
        # column, so drop the last frequency to keep the shapes aligned.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, t):
        """Return the first ``t.size(1)`` rows of the table.

        Only the size of ``t`` along dimension 1 is used; the values stored
        in ``t`` are not read. The slice yields at most ``lmax`` rows.

        Args:
            t: Tensor with at least two dimensions.

        Returns:
            Tensor of shape ``(min(t.size(1), lmax), d_model)``.
        """
        return self.pe[:t.size(1)]

class FeedForwardNetwork(nn.Module):
    """Single affine projection from ``d_in`` to ``d_model`` features."""

    def __init__(self, d_in, d_model):
        super().__init__()
        self.mlp = nn.Linear(in_features=d_in, out_features=d_model)

    def forward(self, x):
        """Apply the linear layer to the last dimension of ``x``."""
        projected = self.mlp(x)
        return projected

class FlashMultiHeadAttention(nn.Module):
    """Multi-head self-attention with optional key-padding mask.

    Queries, keys and values come from one fused linear projection. The
    attention itself is computed by
    ``torch.nn.functional.scaled_dot_product_attention``, which accepts a
    padding mask and runs in any floating dtype on CPU or GPU.

    Args:
        d_model: Feature width of the input and output.
        num_heads: Number of attention heads; must divide ``d_model``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        assert self.head_dim * num_heads == d_model, "d_model must be divisible by num_heads"

        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, key_padding_mask=None):
        """Run self-attention over ``x``.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, d_model)``.
            key_padding_mask: Optional ``(batch_size, seq_len)`` mask whose
                truthy entries mark padded positions that must not be
                attended to.

        Returns:
            Tensor of shape ``(batch_size, seq_len, d_model)``.
        """
        batch_size, seq_len, _ = x.shape
        qkv = self.qkv_proj(x).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        # Three tensors, each [batch_size, num_heads, seq_len, head_dim].
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        attn_mask = None
        if key_padding_mask is not None:
            # The incoming mask flags padding, while a boolean attn_mask flags
            # the keys that MAY be attended to, hence the inversion. The two
            # singleton dims broadcast over heads and query positions.
            # NOTE(review): a sequence whose keys are all padded has nothing
            # left to attend to; expect non-finite outputs for it - confirm
            # that zero-length sequences never reach this layer.
            keep = ~key_padding_mask.to(device=x.device, dtype=torch.bool)
            attn_mask = keep[:, None, None, :]

        attn_output = nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, is_causal=False
        )
        # Bring heads back next to head_dim before merging them, otherwise
        # the reshape would interleave sequence positions with heads.
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)

        return self.out_proj(attn_output)

class EncoderLayer(nn.Module):
    """One encoder block: self-attention then a two-layer MLP.

    Each sub-layer is wrapped in a residual connection followed by
    ``LayerNorm`` (normalisation is applied after the addition).
    """

    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.mha = FlashMultiHeadAttention(d_model, num_heads)
        widen = nn.Linear(d_model, d_ff)
        narrow = nn.Linear(d_ff, d_model)
        self.ff = nn.Sequential(widen, nn.ReLU(), narrow)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, src_key_padding_mask=None):
        """Apply the attention and feed-forward sub-layers to ``x``."""
        attended = self.mha(x, key_padding_mask=src_key_padding_mask)
        hidden = self.norm1(x + attended)
        transformed = self.ff(hidden)
        return self.norm2(hidden + transformed)

class Encoder(nn.Module):
    """Stack of ``num_layers`` identical ``EncoderLayer`` blocks applied in order."""

    def __init__(self, d_model, num_heads, d_ff, num_layers):
        super().__init__()
        blocks = []
        for _ in range(num_layers):
            blocks.append(EncoderLayer(d_model, num_heads, d_ff))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x, src_key_padding_mask=None):
        """Feed ``x`` through every layer, passing the same padding mask to each."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, src_key_padding_mask)
        return hidden

class Decoder(nn.Module):
    """Linear read-out mapping ``d_model`` features to ``d_out`` outputs."""

    def __init__(self, d_model, d_out):
        super().__init__()
        self.linear = nn.Linear(in_features=d_model, out_features=d_out)

    def forward(self, x):
        """Apply the linear layer to the last dimension of ``x``."""
        decoded = self.linear(x)
        return decoded

class Astromer(nn.Module):
    """Masked-reconstruction encoder for padded magnitude sequences.

    Magnitudes are randomly zeroed, projected to ``d_model`` features, summed
    with a positional-encoding table, passed through the encoder stack and
    mapped back to one value per time step.

    Args:
        d_model: Feature width used throughout the model.
        num_heads: Attention heads per encoder layer.
        d_ff: Hidden width of each encoder feed-forward block.
        num_layers: Number of encoder layers.
        lmax: Number of positions in the positional-encoding table.
        base: Base forwarded to ``PositionalEncoding``.
    """

    def __init__(self, d_model, num_heads, d_ff, num_layers, lmax, base=1000):
        super().__init__()
        self.positional_encoding = PositionalEncoding(d_model, lmax, base)
        self.fnn = FeedForwardNetwork(1, d_model)
        self.encoder = Encoder(d_model, num_heads, d_ff, num_layers)
        self.decoder = Decoder(d_model, 1)

    def forward(self, times, magnitudes, lengths, mask_prob=0.15):
        """Reconstruct ``magnitudes`` from a randomly masked copy.

        Random masking is applied on every call, regardless of whether the
        module is in training or evaluation mode.

        Args:
            times: ``(batch, seq)`` tensor passed to the positional encoding.
            magnitudes: ``(batch, seq)`` tensor of padded magnitudes.
            lengths: ``(batch,)`` integer tensor of valid steps per sequence.
            mask_prob: Probability with which each valid step is zeroed.

        Returns:
            ``(batch, seq)`` tensor of reconstructed magnitudes.
        """
        pe = self.positional_encoding(times)
        masked_magnitudes = self.mask_magnitudes(magnitudes, lengths, mask_prob)
        x = pe.unsqueeze(0).expand(magnitudes.size(0), -1, -1) + self.fnn(masked_magnitudes.unsqueeze(-1))
        padding_mask = self.create_padding_mask(lengths, times.size(1))
        encoded = self.encoder(x, src_key_padding_mask=padding_mask)
        reconstructed = self.decoder(encoded).squeeze(-1)
        return reconstructed

    def mask_magnitudes(self, magnitudes, lengths, mask_prob):
        """Return a copy of ``magnitudes`` with random valid steps set to 0.

        Each step is selected independently with probability ``mask_prob``;
        only steps before ``lengths[i]`` in row ``i`` are modified. The input
        tensor is left untouched.
        """
        selected = torch.rand_like(magnitudes) < mask_prob
        lengths = torch.as_tensor(lengths, device=magnitudes.device)
        steps = torch.arange(magnitudes.size(1), device=magnitudes.device)
        # One broadcast comparison replaces the per-sample Python loop.
        valid = steps.unsqueeze(0) < lengths.unsqueeze(1)
        return magnitudes.masked_fill(selected & valid, 0)

    def create_padding_mask(self, lengths, max_length):
        """Return a ``(batch, max_length)`` bool mask that is True on padding.

        The position index is created on ``lengths.device`` so the comparison
        also works when ``lengths`` lives on the GPU.
        """
        steps = torch.arange(max_length, device=lengths.device)
        return steps.unsqueeze(0) >= lengths.unsqueeze(1)


In [None]:
# Training function
def train(model, optimizer, train_loader, test_loader, num_epochs):
    """Train ``model`` and report the masked MSE on both loaders each epoch.

    The model is moved to CUDA when available, otherwise it stays on the CPU.
    Every batch must unpack to ``(times, magnitudes, lengths)`` tensors and
    the model is called as ``model(times, magnitudes, lengths)``.

    Args:
        model: Module to optimise.
        optimizer: Optimiser bound to the model's parameters.
        train_loader: Iterable of training batches.
        test_loader: Iterable of evaluation batches.
        num_epochs: Number of passes over ``train_loader``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    for epoch in range(num_epochs):
        model.train()
        train_loss = _run_epoch(model, train_loader, device, optimizer)

        model.eval()
        with torch.no_grad():
            test_loss = _run_epoch(model, test_loader, device)

        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")


def _masked_mse(reconstructed, magnitudes, lengths):
    """Mean squared error over the first ``lengths[i]`` steps of each row."""
    squared_error = nn.functional.mse_loss(reconstructed, magnitudes, reduction='none')
    steps = torch.arange(magnitudes.size(1), device=magnitudes.device)
    valid = steps.unsqueeze(0) < lengths.unsqueeze(1)
    # Clamp the count so a batch without any valid step yields a zero loss
    # instead of 0/0 = NaN, which would otherwise poison the weights.
    return (squared_error * valid.float()).sum() / valid.sum().clamp(min=1)


def _run_epoch(model, loader, device, optimizer=None):
    """Average the per-batch masked MSE over ``loader``; step if an optimiser is given."""
    total_loss = 0.0
    num_batches = 0
    for times, magnitudes, lengths in loader:
        times, magnitudes, lengths = times.to(device), magnitudes.to(device), lengths.to(device)
        if optimizer is not None:
            optimizer.zero_grad()
        reconstructed = model(times, magnitudes, lengths)
        loss = _masked_mse(reconstructed, magnitudes, lengths)
        if optimizer is not None:
            loss.backward()
            optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    # An empty loader has no loss to report.
    return total_loss / num_batches if num_batches else float('nan')


In [7]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

class LightCurveDataset(Dataset):
    """Dataset that zero-pads each light curve to a common length.

    Args:
        times: Sequence of per-sample time arrays.
        magnitudes: Sequence of per-sample magnitude arrays, aligned with ``times``.
        max_length: Length every returned tensor is padded to.
    """

    def __init__(self, times, magnitudes, max_length):
        self.times = times
        self.magnitudes = magnitudes
        self.max_length = max_length

    def __len__(self):
        return len(self.times)

    def _pad(self, values, length):
        """Copy ``values`` into the front of a zero tensor of ``max_length``."""
        padded = torch.zeros(self.max_length)
        padded[:length] = torch.tensor(values, dtype=torch.float32)
        return padded

    def __getitem__(self, idx):
        """Return ``(padded_time, padded_magnitude, length)`` for sample ``idx``.

        ``length`` is the unpadded number of observations, taken from the
        time array.
        """
        raw_time = self.times[idx]
        raw_magnitude = self.magnitudes[idx]
        n_obs = len(raw_time)

        padded_time = self._pad(raw_time, n_obs)
        padded_magnitude = self._pad(raw_magnitude, n_obs)
        return padded_time, padded_magnitude, n_obs

# Read the light-curve table from CSV
df = pd.read_csv('../Data/synthetic_light_curves.csv')

# One group per sample_id
grouped = df.groupby('sample_id')

# Collect the time and magnitude arrays of every sample in a single pass
times, magnitudes = [], []
for _, group in grouped:
    times.append(group['time_mjd'].values)
    magnitudes.append(group['magnitude'].values)

# Sequence lengths and the padding target derived from them
lengths = list(map(len, times))
max_length = max(lengths)

# Hold out 20% of the samples for evaluation (fixed split seed)
train_times, test_times, train_mags, test_mags = train_test_split(
    times,
    magnitudes,
    test_size=0.2,
    random_state=42,
)

# Wrap both splits as padded datasets
train_dataset = LightCurveDataset(train_times, train_mags, max_length)
test_dataset = LightCurveDataset(test_times, test_mags, max_length)

# Batch the datasets; only the training split is shuffled
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of test samples: {len(test_dataset)}")
print(f"Maximum sequence length: {max_length}")

Number of training samples: 800
Number of test samples: 200
Maximum sequence length: 100


In [None]:
# Set up the model and training
# Seed PyTorch's global RNG so that weight initialisation (and the torch
# random draws that follow it) repeat on a fresh Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

d_model = 128
num_heads = 4
d_ff = 256
num_layers = 3
lmax = max_length  # Use the actual maximum length from the data
base = 1000

model = Astromer(d_model, num_heads, d_ff, num_layers, lmax, base)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


In [None]:
# Run the training loop for a fixed number of epochs
NUM_EPOCHS = 10
train(model, optimizer, train_loader, test_loader, num_epochs=NUM_EPOCHS)