In [1]:
import math
from typing import List, Optional, Tuple, Union
from abc import ABC, abstractmethod

import datasets
import torch
import numpy as np
import torch.nn as nn
import logging
import copy
import gc

from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    AutoConfig,
    AutoModelForCausalLM,
    LlamaForCausalLM,
    LlamaConfig,
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    HfArgumentParser
)

from modeling_qwen2 import (
    Qwen2RMSNorm, 
    Qwen2RotaryEmbedding, 
    Qwen2MLP, 
    Qwen2Attention, 
    Qwen2FlashAttention2, 
    Qwen2SdpaAttention, 
    Qwen2DecoderLayer, 
    Qwen2PreTrainedModel, 
    Qwen2Model, 
    Qwen2ForCausalLM,
)

from configuration_qwen2 import Qwen2Config

from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
)

# Configure logger
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

In [2]:
def free_memory():
    if not torch.cuda.is_available():
        logger.info("CUDA is not available. No GPU memory to free.")
        return
        
    initial_memory = torch.cuda.memory_allocated()
    logger.info(f"Initial GPU memory allocated: {initial_memory / 1024**3:.2f} GB")
    gc.collect()
    torch.cuda.empty_cache()

    final_memory = torch.cuda.memory_allocated()
    logger.info(f"Final GPU memory allocated: {final_memory / 1024**3:.2f} GB")

    freed_memory = initial_memory - final_memory
    logger.info(f"Freed GPU memory: {freed_memory / 1024**3:.2f} GB")

In [3]:
class MaskConfig(PretrainedConfig):
    def __init__(
        self,
        mode: str = None,
        value: Union[float, torch.Tensor] = None,
        size: torch.Size = None,
        **kwargs,
    ):
        self.mode = mode
        self.value = value
        self.size = size
        super().__init__(**kwargs)

class Mask(nn.Module):
    def __init__(self, mask_config: MaskConfig):
        super().__init__()
        self.config = mask_config
        self.size = mask_config.size
        assert self.size is not None, "Mask size must be specified."

        value = mask_config.value
        if mask_config.mode == "scalar":
            self.weight = nn.Parameter(torch.tensor(value if value is not None else 1.0))
        elif mask_config.mode in ("vector_input", "vector_output"):
            ones = self._get_ones(mask_config.mode)
            self.weight = nn.Parameter(value if value is not None else ones)
        else:
            raise ValueError(f"Unsupported mask mode: {mask_config.mode}")

        self._check_shape_compatibility()

    def _get_ones(self, mode: str) -> torch.Tensor:
        """Generates a tensor of ones based on mode and size."""
        dim = 0 if mode == "vector_output" else -1
        features = self.size[dim]
        if len(self.size) == 2 and mode == "vector_output":
            return torch.ones(features, 1)
        else:
            return torch.ones(features)
          

    def _check_shape_compatibility(self):
        """Raises ValueError if the mask shape is incompatible with its size."""
        try:
            in_test = torch.rand(self.size)
            out_test = self.weight * in_test
            assert out_test.shape == in_test.shape, (
                "After applying mask, the shape of input weight does not stay the same."
            )
        except RuntimeError:
            raise ValueError("Mask initialized with an incompatible shape.")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.size != x.shape:
            logger.warning("Warning: Input shape does not match mask shape.")
        return x * self.weight

In [66]:
class ModuleWithMask(nn.Module, ABC):
    def __init__(self, *args, **kwargs):
        super(ModuleWithMask, self).__init__()

    @abstractmethod
    def forward(self, x):
        pass

class ModulesWithMasks(nn.Module, ABC):
    def __init__(self, *args, **kwargs):
        super(ModulesWithMasks, self).__init__()

    @abstractmethod
    def forward(self, x):
        pass
        
    @abstractmethod
    def get_raw_masks(self):
        pass
        
    @abstractmethod
    def get_constrained_masks(self):
        pass

In [243]:
in_features = 24
out_features = 50
lin1 = nn.Linear(in_features, out_features, bias=False)
lin2 = nn.Linear(in_features, out_features, bias=False)

In [250]:
dim = 1
a = lin1.weight.data * lin2.weight.data
torch.sum(a, dim=0, keepdim=True)

tensor([[ 0.1379, -0.0828,  0.0579,  0.0869, -0.2780,  0.0153,  0.1296, -0.1406,
          0.0408, -0.0731, -0.0952,  0.0755, -0.0227,  0.2615, -0.0460,  0.0749,
         -0.0550,  0.0341,  0.0937,  0.0736, -0.1101, -0.0626,  0.0726, -0.0169]])

In [253]:
sum(lin1.weight.data[:, 0] * lin2.weight.data[:, 0])

tensor(0.1379)

In [255]:
v = lin1.weight.data
v, v[0, :], v[:, 0]

(tensor([[ 0.1848,  0.1494,  0.1646,  ..., -0.0032, -0.0555, -0.1466],
         [-0.1290, -0.0756, -0.0223,  ..., -0.1856,  0.0470, -0.0283],
         [-0.1999, -0.1282, -0.0962,  ..., -0.0889, -0.0963, -0.1632],
         ...,
         [-0.0054, -0.0114,  0.0129,  ..., -0.0470,  0.0084,  0.1470],
         [-0.1827, -0.1948, -0.1714,  ..., -0.0504,  0.1090, -0.1485],
         [-0.0404,  0.1018, -0.0064,  ..., -0.1330,  0.0322, -0.0790]]),
 tensor([ 0.1848,  0.1494,  0.1646, -0.0410, -0.1585, -0.0646,  0.1476, -0.0333,
         -0.1931, -0.1558, -0.1577, -0.1567,  0.2007, -0.0615,  0.0303, -0.1447,
         -0.1068,  0.1688,  0.0174, -0.1891, -0.0137, -0.0032, -0.0555, -0.1466]),
 tensor([ 0.1848, -0.1290, -0.1999,  0.0977,  0.1508, -0.1669, -0.0660,  0.1283,
         -0.0557,  0.1243,  0.0842, -0.0416, -0.1729, -0.1151,  0.1024,  0.0945,
          0.0715, -0.0120, -0.1028,  0.0120, -0.0981,  0.0432,  0.1506,  0.0452,
          0.1851, -0.0963,  0.1388,  0.0774, -0.1656,  0.0938,  0.2018

In [81]:
a = torch.rand(5) * torch.rand(5)
torch.sum(a, dim=0, keepdim=True)

tensor([0.8753])

In [87]:
torch.sigmoid(torch.rand(5))
torch.relu(torch.rand(5))

tensor([0.5828, 0.3854, 0.0289, 0.7288, 0.4752])

In [89]:
torch.exp(torch.rand(5))

tensor([2.3835, 1.7104, 1.0538, 1.0995, 1.5433])

In [256]:
class Constrainer(nn.Module):
    """
    Take as input modules to calculate statistics.
    Forward method takes as input mask weights.
    If modules are Linears, need to take care of 
    """
    def __init__(self, component_weights, constrain_mode):
        super().__init__()
        self.statistics = None
        self.constrain_mode = constrain_mode
        if (self.constrain_mode == "spherical" and 
            all([w is not None for w in component_weights])):
            
            assert len(component_weights) == 2, (
                "Spherical constraint (SLERP) only supports 2 component weights"
            )
            self.dots = torch.sum(component_weights[0] * component_weights[1], dim=0, keepdim=True)
            self.theta_0s = torch.arccos(self.dots)
            self.sin_theta_0s = torch.sin(sin.theta_0s)
        
    def forward(self, mask_weights: List[torch.Tensor]):
        if any([w is None for w in mask_weights]):
            return mask_weights
            
        if self.constrain_mode == "identity":
            return mask_weights
        elif self.constrain_mode == "01":
            mask_weights = [torch.exp(w) for w in mask_weights]
            mask_weights = [w / sum(mask_weights) for w in mask_weights]
            return mask_weights
        elif self.constrain_mode == "-11":
            return mask_weights
        elif self.constrain_mode == "spherical":
            """ Reference implementation
            # Calculate initial angle between v0 and v1
            theta_0 = np.arccos(dot)
            sin_theta_0 = np.sin(theta_0)
        
            # Angle at timestep t
            theta_t = theta_0 * t
            sin_theta_t = np.sin(theta_t)
        
            # Finish the slerp algorithm
            s0 = np.sin(theta_0 - theta_t) / sin_theta_0
            s1 = sin_theta_t / sin_theta_0
            """
            assert len(mask_weights) == 2, (
                "Spherical constraint (SLERP) only supports 2 mask weights"
            )
            mask_weights = [torch.exp(w) for w in mask_weights]
            ts = mask_weights[0] / sum(mask_weights)
            # mask_weights = [w / sum(mask_weights) for w in mask_weights]
            
            # sin_theta_0s = torch.sin(self.theta_0s)

            # Angle at timestep t
            theta_ts = self.theta_0s * ts
            sin_theta_ts = torch.sin(theta_ts)

            # Finish calculating slerp factors
            S0 = torch.sin(self.theta_0s - theta_ts) / self.sin_theta_0s
            S1 = sin_theta_ts / self.sin_theta_0s
            return [S0, S1]
        else:
            raise ValueError(f"Does not support {self.constrain_mode} constraint yet!")

In [257]:
class LinearsWithMasks(ModulesWithMasks):
    def __init__(
        self,
        linears: List[nn.Linear],
        weight_modes: List[str] = ["scalar"],
        weight_values: List[float] = None,
        bias_modes: List[str] = ["scalar"],
        bias_values: List[float] = None,
        constrain_mode: str = "identity"
    ):
        super().__init__()

        if not all(isinstance(linear, nn.Linear) for linear in linears):
            raise ValueError("All elements in 'linears' must be instances of nn.Linear.")

        if weight_values is None or len(weight_values) != len(linears):
            raise ValueError(
                f"weight_values for masks: {weight_values} do not match with linear layers: {linears}"
            )
        if bias_values is None:
            bias_values = [None] * len(linears)
        if len(bias_values) != len(linears):
            raise ValueError(
                f"bias_values for masks: {bias_values} do not match with linear layers: {linears}"
            )

        self.linears = nn.ModuleList(linears)
        self.constrain_mode = constrain_mode

        self.weight_masks = nn.ModuleList([
            Mask(MaskConfig(mode, value, linear.weight.shape))
            for mode, value, linear in zip(weight_modes, weight_values, linears)
        ])
        self.weight_masks_constrainer = Constrainer(
            component_weights=[x.weight for x in linears],
            constrain_mode=constrain_mode
        )

        self.bias_masks = nn.ModuleList([
            Mask(MaskConfig(mode, value, linear.bias.shape)) if linear.bias is not None else None
            for mode, value, linear in zip(bias_modes, bias_values, linears)
        ])
        self.bias_masks_constrainer = Constrainer(
            component_weights=[x.bias if x.bias is not None else None for x in linears],
            constrain_mode=constrain_mode
        )

    def forward(self, x):
        constrained_weight_masks = self.weight_masks_constrainer(
            [mask.weight for mask in self.weight_masks]
        )
        weights = [
            weight_mask * linear.weight for weight_mask, linear
            in zip(constrained_weight_masks, self.linears)
        ]
        merged_weight = sum(weights)

        constrained_bias_masks = self.bias_masks_constrainer([
            mask.weight if mask is not None else None for mask in self.bias_masks
        ])
        biases = [
            bias_mask * linear.bias if linear.bias is not None and bias_mask is not None
            else linear.bias for bias_mask, linear in zip(constrained_bias_masks, self.linears)
        ]

        if all(b is None for b in biases):
            merged_bias = None
        else:
            merged_bias = sum(
                b if b is not None 
                else torch.zeros_like(merged_weight[:, 0])
                for b in biases
            )

        return nn.functional.linear(x, merged_weight, merged_bias)

    def get_raw_masks(self):
        with torch.no_grad():
            return dict(
                weight_masks=[m.weight for m in self.weight_masks],
                bias_masks=[m.weight if m is not None else None for m in self.bias_masks]
            )
        
    def get_constrained_masks(self):
        with torch.no_grad():
            constrained_weight_masks = self.weight_masks_constrainer(
                [mask.weight for mask in self.weight_masks]
            )
            constrained_bias_masks = self.bias_masks_constrainer([
                mask.weight if mask is not None else None for mask in self.bias_masks
            ])
            return dict(
                weight_masks=constrained_weight_masks,
                bias_masks=constrained_bias_masks
            )

In [125]:
from masks import LinearsWithMasks as LinearsWithMasksRef

In [258]:
def slerp(
    t: Union[float, np.ndarray],
    v0: Union[np.ndarray, torch.Tensor],
    v1: Union[np.ndarray, torch.Tensor],
    DOT_THRESHOLD: float = 0.9995,
    eps: float = 1e-8,
):
    """
    Spherical linear interpolation

    From: https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
    Args:
        t (float/np.ndarray): Float value between 0.0 and 1.0
        v0 (np.ndarray): Starting vector
        v1 (np.ndarray): Final vector
        DOT_THRESHOLD (float): Threshold for considering the two vectors as
                               colinear. Not recommended to alter this.
    Returns:
        v2 (np.ndarray): Interpolation vector between v0 and v1
    """
    is_torch = False
    if not isinstance(v0, np.ndarray):
        is_torch = True
        v0 = v0.detach().cpu().float().numpy()
    if not isinstance(v1, np.ndarray):
        is_torch = True
        v1 = v1.detach().cpu().float().numpy()

    # Copy the vectors to reuse them later
    v0_copy = np.copy(v0)
    v1_copy = np.copy(v1)

    # Normalize the vectors to get the directions and angles
    v0 = normalize(v0, eps)
    v1 = normalize(v1, eps)
    # import ipdb; ipdb.set_trace()

    # Dot product with the normalized vectors (can't use np.dot in W)
    dot = np.sum(v0 * v1)

    # If absolute value of dot product is almost 1, vectors are ~colinear, so use lerp
    if np.abs(dot) > DOT_THRESHOLD:
        res = lerp(t, v0_copy, v1_copy)
        return maybe_torch(res, is_torch)

    # Calculate initial angle between v0 and v1
    theta_0 = np.arccos(dot)
    sin_theta_0 = np.sin(theta_0)

    # Angle at timestep t
    theta_t = theta_0 * t
    sin_theta_t = np.sin(theta_t)

    # Finish the slerp algorithm
    s0 = np.sin(theta_0 - theta_t) / sin_theta_0
    s1 = sin_theta_t / sin_theta_0
    return [dot, theta_0, theta_t, s0, s1]

def maybe_torch(v: np.ndarray, is_torch: bool):
    if is_torch:
        return torch.from_numpy(v)
    return v


def normalize(v: np.ndarray, eps: float):
    norm_v = np.linalg.norm(v)
    if norm_v > eps:
        v = v / norm_v
    return v

In [259]:
input_size = 20
output_size = 40
num_components = 2

linears_with_bias = [nn.Linear(input_size, output_size, bias=True) for _ in range(num_components)]
# Test without bias
linears_without_bias = [nn.Linear(input_size, output_size, bias=False) for _ in range(num_components)]

x = torch.rand(1, input_size)

# for _ in range(10):  # Reduced number of iterations for faster testing
weight_values = np.random.rand(num_components).tolist()
weight_values = [None] * num_components
bias_values = np.random.rand(num_components).tolist()
bias_values = [None] * num_components

# Test with bias
masked_linears = LinearsWithMasks(
    linears=linears_without_bias,
    weight_modes=["vector_input"] * num_components,
    weight_values=weight_values,
    bias_modes=["vector_input"] * num_components,
    bias_values=bias_values,
    constrain_mode = "spherical"
)

# masked_linears_ref = LinearsWithMasksRef(
#     linears=linears_without_bias,
#     weight_modes=["vector_input"] * num_components,
#     weight_values=weight_values,
#     bias_modes=["vectovector_inputr_input"] * num_components,
#     bias_values=bias_values,
# )

In [56]:
# torch.testing.assert_close(masked_linears(x), masked_linears_ref(x), atol=0, rtol=0)

In [241]:
# masked_linears(x)

In [260]:
masked_linears.get_constrained_masks()

tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000])


{'weight_masks': [tensor([[0.6948, 0.6881, 0.7260, 0.6823, 0.6489, 0.7358, 0.7699, 0.7335, 0.7475,
           0.7471, 0.6594, 0.6986, 0.6972, 0.6598, 0.6798, 0.7235, 0.7175, 0.6985,
           0.6656, 0.7163]]),
  tensor([[0.6948, 0.6881, 0.7260, 0.6823, 0.6489, 0.7358, 0.7699, 0.7335, 0.7475,
           0.7471, 0.6594, 0.6986, 0.6972, 0.6598, 0.6798, 0.7235, 0.7175, 0.6985,
           0.6656, 0.7163]])],
 'bias_masks': [None, None]}

In [261]:
constrainer = Constrainer(
    component_weights=[x.weight for x in linears_without_bias],
    constrain_mode="spherical"
)

In [262]:
constrainer.theta_0s

tensor([[1.5351, 1.5148, 1.6222, 1.4966, 1.3822, 1.6472, 1.7280, 1.6415, 1.6763,
         1.6751, 1.4201, 1.5464, 1.5423, 1.4216, 1.4887, 1.6157, 1.5995, 1.5459,
         1.4420, 1.5964]], grad_fn=<AcosBackward0>)

In [266]:
constrainer.dots

tensor([[ 0.0357,  0.0559, -0.0514,  0.0741,  0.1875, -0.0764, -0.1565, -0.0706,
         -0.1053, -0.1041,  0.1501,  0.0244,  0.0285,  0.1487,  0.0820, -0.0448,
         -0.0287,  0.0249,  0.1285, -0.0256]], grad_fn=<SumBackward1>)

In [264]:
v0 = linears_without_bias[0].weight.data[:, -1]
v1 = linears_without_bias[1].weight.data[:, -1]
slerp(
    0.5, v0, v1
)

[np.float32(-0.033729836),
 np.float32(1.6045326),
 np.float32(0.8022663),
 np.float32(0.7193425),
 np.float32(0.7193425)]

In [235]:
torch.sum(v0 * v1)

tensor(0.0657)

In [223]:
linears_without_bias[1].weight.data[1], linears_without_bias[0].weight.data[1]

(tensor([ 0.0900, -0.1653,  0.2175, -0.0066, -0.1509,  0.1738, -0.2021,  0.0056,
          0.0747,  0.0987,  0.0641, -0.1328, -0.2082,  0.0532, -0.0573, -0.0543,
         -0.1366,  0.1805,  0.1399,  0.0938]),
 tensor([-0.0444,  0.1529,  0.1763, -0.1462,  0.1757,  0.0038, -0.1280, -0.0950,
          0.1156, -0.0166, -0.1705,  0.1983,  0.0292,  0.1613, -0.0939, -0.0343,
          0.0375,  0.2133,  0.0792, -0.1354]))