In [None]:
!pip install mlflow


In [2]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
from transformers import Trainer, TrainingArguments, AdamW
import torch
import torch.nn.functional as F
import mlflow
import mlflow.pytorch
import math
import os

In [None]:
config = GPT2Config()
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

In [3]:
config.n_embd = 256
config.n_head = 4
config.n_layer = 4
config.n_positions = 128
model_kan = GPT2LMHeadModel(config)

In [11]:
class DataLoaderLite:
  def __init__(self, B, T, process_rank, num_processes, train_ratio=0.6, valid_ratio=0.2, mode='train'):
      self.B = B
      self.T = T
      self.process_rank = 0
      self.num_processes = 1
      self.train_ratio = train_ratio
      self.valid_ratio = valid_ratio
      self.mode = mode

      with open('input.txt', 'r') as f:
          text = f.read()
      enc = tokenizer
      tokens = enc.encode(text)
      self.tokens = torch.tensor(tokens)

      total_length = len(self.tokens)
      split_index_train = int(total_length * self.train_ratio)
      split_index_val = int(total_length * (self.train_ratio + self.valid_ratio))

      if self.mode == 'train':
          self.tokens = self.tokens[:split_index_train]
      elif self.mode == 'val':
          self.tokens = self.tokens[split_index_train:split_index_val]
      else:
          self.tokens = self.tokens[split_index_val:]

      if self.process_rank == 0:
          print(f"Loaded {len(self.tokens)} tokens for {self.mode} set")

      self.current_position = self.B * self.T * self.process_rank

  def next_batch(self):
      B, T = self.B, self.T
      buf = self.tokens[self.current_position : self.current_position + B * T * self.num_processes + 1]
      x = (buf[:-1]).view(B, T)  # inputs
      y = (buf[1:]).view(B, T)    # targets
      self.current_position += B * T * self.num_processes
      if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
          self.current_position = self.B * self.T * self.process_rank
      return x, y

B = 32  
T = 32  
train_ratio = 0.6  
valid_ratio = 0.2   
test_ratio = 0.2    

train_loader = DataLoaderLite(B, T, train_ratio, valid_ratio, mode='train')
val_loader = DataLoaderLite(B, T, train_ratio, valid_ratio, mode='val')
test_loader = DataLoaderLite(B, T, train_ratio, valid_ratio, mode='test')

x_train, y_train = train_loader.next_batch()
print("Training batch:", x_train.shape, y_train.shape)

x_val, y_val = val_loader.next_batch()
print("Validation batch:", x_val.shape, y_val.shape)

x_test, y_test = test_loader.next_batch()
print("Test batch:", x_test.shape, y_test.shape)

Token indices sequence length is longer than the specified maximum sequence length for this model (338025 > 1024). Running this sequence through the model will result in indexing errors


Loaded 202815 tokens for train set
Loaded 67605 tokens for val set
Loaded 67605 tokens for test set
Training batch: torch.Size([32, 32]) torch.Size([32, 32])
Validation batch: torch.Size([32, 32]) torch.Size([32, 32])
Test batch: torch.Size([32, 32]) torch.Size([32, 32])


In [6]:
# mlflow_log_dir = '/content/drive/My Drive/KAN/mlflow_logs_2'
mlflow_log_dir = 'mlflow_logs'
os.makedirs(mlflow_log_dir, exist_ok=True)
mlflow.set_tracking_uri(mlflow_log_dir)

In [7]:
class KANLinear(torch.nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        assert x.size(-1) == self.in_features
        original_shape = x.shape
        x = x.reshape(-1, self.in_features)

        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        output = base_output + spline_output

        output = output.reshape(*original_shape[:-1], self.out_features)
        return output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        grid = torch.concatenate(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a dumb simulation of the original L1 regularization as stated in the
        paper, since the original one requires computing absolutes and entropy from the
        expanded (batch, in_features, out_features) intermediate tensor, which is hidden
        behind the F.linear function if we want an memory efficient implementation.

        The L1 regularization is now computed as mean absolute value of the spline
        weights. The authors implementation also includes this term in addition to the
        sample-based regularization.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )


class KAN(torch.nn.Module):
    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KAN, self).__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order

        self.layers = torch.nn.ModuleList()
        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(
                KANLinear(
                    in_features,
                    out_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )
            )

    def forward(self, x: torch.Tensor, update_grid=False):
        for layer in self.layers:
            if update_grid:
                layer.update_grid(x)
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        return sum(
            layer.regularization_loss(regularize_activation, regularize_entropy)
            for layer in self.layers
        )

In [8]:
# model_kan = model

# Set the weights in the MLP layer to random weights
# for name, param in model.named_parameters():
#     if 'mlp' in name:
#         param.data = torch.randn_like(param)

# for name, param in model.named_parameters():
#   if 'mlp' not in name:
#       param.requires_grad = False



# Function to replace the MLP layers with KAN layers
def replace_mlp_with_kan(model, kan_layer):
    for block in model.transformer.h:
        input_dim = block.mlp.c_fc.weight.size(1)
        output_dim = block.mlp.c_proj.weight.size(0)

        block.mlp = kan_layer

KANLayer = KAN([256, 256])
replace_mlp_with_kan(model_kan, KANLayer)

In [None]:
model = KAN([256, 256])

total_params = sum(p.numel() for p in model.parameters())
print(f'Total number of parameters: {total_params}')

Total number of parameters: 655360


In [9]:
optimizer = AdamW(model_kan.parameters(), lr=5e-5)
tokenizer.pad_token = tokenizer.eos_token

num_train_epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_kan.to(device)

experiment_name = "My_Experiment"  
mlflow.set_experiment(experiment_name)

mlflow.start_run()

mlflow.log_param("learning_rate", 5e-5)
mlflow.log_param("num_train_epochs", num_train_epochs)

model_kan.train()
for epoch in range(num_train_epochs):
  for step in range(len(train_loader.tokens) // (B * T)):
      batch = train_loader.next_batch()
      input_ids = batch[0].to(device)
      labels = batch[1].to(device)

      outputs = model_kan(input_ids, labels=labels)
      loss = outputs.loss

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      if step % 100 == 0:
          print(f"Epoch {epoch + 1}, Step {step}, Training Loss: {loss.item()}")
          mlflow.log_metric("training_loss", loss.item(), step=step + epoch * (len(train_loader.tokens) // (B * T)))

  # Validation phase
  model_kan.eval()
  val_loss = 0.0
  num_val_batches = len(val_loader.tokens) // (B * T)

  with torch.no_grad():  
      for val_step in range(num_val_batches):
          val_batch = val_loader.next_batch()
          val_input_ids = val_batch[0].to(device)
          val_labels = val_batch[1].to(device)

          val_outputs = model_kan(val_input_ids, labels=val_labels)
          val_loss += val_outputs.loss.item()

  val_loss /= num_val_batches
  print(f"Epoch {epoch + 1}, Validation Loss: {val_loss}")
  mlflow.log_metric("validation_loss", val_loss, step=epoch)

# Log the model
mlflow.pytorch.log_model(model_kan, "model")
mlflow.end_run()
print("Training complete.")

2024/08/06 17:26:17 INFO mlflow.tracking.fluent: Experiment with name 'My_Experiment' does not exist. Creating a new experiment.


Epoch 1, Step 0, Training Loss: 10.909412384033203
Epoch 1, Step 100, Training Loss: 9.147011756896973
Epoch 1, Validation Loss: 7.740481947407578
Epoch 2, Step 0, Training Loss: 7.573943614959717
Epoch 2, Step 100, Training Loss: 7.071108818054199
Epoch 2, Validation Loss: 6.687604665756226
Epoch 3, Step 0, Training Loss: 6.413384437561035
Epoch 3, Step 100, Training Loss: 6.687142372131348
Epoch 3, Validation Loss: 6.552627390081232
Epoch 4, Step 0, Training Loss: 6.235752582550049
Epoch 4, Step 100, Training Loss: 6.5607757568359375
Epoch 4, Validation Loss: 6.501803174163356
Epoch 5, Step 0, Training Loss: 6.136997699737549
Epoch 5, Step 100, Training Loss: 6.459405422210693
Epoch 5, Validation Loss: 6.470811193639582
Epoch 6, Step 0, Training Loss: 6.049208641052246
Epoch 6, Step 100, Training Loss: 6.352814674377441
Epoch 6, Validation Loss: 6.430894244800914
Epoch 7, Step 0, Training Loss: 5.951621055603027
Epoch 7, Step 100, Training Loss: 6.228390216827393
Epoch 7, Validation 



Epoch 10, Validation Loss: 6.366308920311205
Training complete.


In [10]:
def evaluate_perplexity(model, dataloader):
  model.eval()
  total_loss = 0.0
  total_count = 0

  with torch.no_grad():
      for step in range(len(dataloader.tokens) // (B * T)):
          batch = dataloader.next_batch()
          input_ids = batch[0].to(device)
          labels = batch[1].to(device)

          outputs = model(input_ids, labels=labels)
          loss = outputs.loss
          total_loss += loss.item() * input_ids.size(0)
          total_count += input_ids.size(0)

  avg_loss = total_loss / total_count
  perplexity = math.exp(avg_loss)
  return perplexity

# Evaluate perplexity on the test dataset
test_perplexity = evaluate_perplexity(model_kan, test_loader)
print(f'Test Perplexity: {test_perplexity}')

Test Perplexity: 620.127397026103
