In [1]:
# Clone the Hugging Face repo `pt-sk/m` into ./m.
# NOTE(review): a later cell loads a checkpoint from /kaggle/working/m — presumably this clone; confirm the working directory.
!git clone https://huggingface.co/pt-sk/m

Cloning into 'm'...
remote: Enumerating objects: 32, done.[K
remote: Counting objects: 100% (28/28), done.[K
remote: Compressing objects: 100% (28/28), done.[K
remote: Total 32 (delta 10), reused 0 (delta 0), pack-reused 4 (from 1)[K
Unpacking objects: 100% (32/32), 11.74 KiB | 1.47 MiB/s, done.


In [2]:
!pip install datasets
!pip install einops
!pip install trl
!pip install transformers
!pip install transformers[torch]
!pip install accelerate -U
!pip install fairscale

Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m627.3 kB/s[0m eta [36m0:00:00[0m[36m0:00:01[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.8.0
Collecting trl
  Downloading trl-0.9.3-py3-none-any.whl.metadata (11 kB)
Collecting tyro>=0.5.11 (from trl)
  Downloading tyro-0.8.4-py3-none-any.whl.metadata (7.9 kB)
Collecting shtab>=1.5.6 (from tyro>=0.5.11->trl)
  Downloading shtab-1.7.1-py3-none-any.whl.metadata (7.3 kB)
Downloading trl-0.9.3-py3-none-any.whl (226 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m226.6/226.6 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m0m
[?25hDownloading tyro-0.8.4-py3-none-any.whl (102 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m102.4/102.4 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloadi

In [3]:
# necessary libraries
from __future__ import annotations
import math
import json
import io

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW

import datasets
from dataclasses import dataclass
from datasets import load_dataset, Dataset
from einops import rearrange, repeat, einsum
from typing import Union

from transformers import AutoTokenizer

import math
import numpy as np
import pandas as pd
import warnings

import fairscale
from tqdm import tqdm
warnings.filterwarnings("ignore")

# import torch_xla
# import torch_xla.core.xla_model as xm

# Device initialization
# device = xm.xla_device()

In [4]:
"""

An implementation of the parallel scan operation in PyTorch (Blelloch version).
Please see docs/pscan.ipynb for a detailed explanation of what happens here.

"""

def npo2(len):
    """
    Returns the smallest power of 2 that is greater than or equal to len
    (len itself when it is already a power of 2).

    Uses exact integer arithmetic rather than float log2, which loses
    precision for large values (e.g. 2**60 + 1 would round down to 2**60).

    Args:
        len : positive length (non-integers are rounded up first)

    Returns:
        int : next power of two >= len

    Raises:
        ValueError : if len < 1
    """

    length = math.ceil(len)
    if length < 1:
        raise ValueError(f"npo2 expects a length >= 1, got {len}")

    # (length - 1).bit_length() is the number of bits needed to index `length` slots
    return 1 << (length - 1).bit_length()

def pad_npo2(X):
    """
    Zero-pads the length dimension (dim 1) at its end so that it becomes
    a power of two, as given by npo2.

    Args:
        X : (B, L, D, N)

    Returns:
        Y : (B, npo2(L), D, N)
    """

    seq_len = X.size(1)
    extra = npo2(seq_len) - seq_len

    # F.pad consumes (left, right) pairs starting from the LAST dim:
    # N -> (0, 0), D -> (0, 0), L -> (0, extra)
    return F.pad(X, (0, 0, 0, 0, 0, extra), "constant", 0)

class PScan(torch.autograd.Function):
    """
    Custom autograd Function computing the linear recurrence
        H[t] = A[t] * H[t-1] + X[t]
    along the length dimension with a Blelloch-style parallel scan.

    `pscan` / `pscan_rev` are the in-place kernels (forward and reversed direction);
    `forward` / `backward` wrap them with cloning/padding and layout changes.
    """

    @staticmethod
    def pscan(A, X):
        """
        In-place forward scan. Returns None; the result is written into X.
        Note that A is also overwritten (it is used as scratch space for the
        accumulated products), so callers must pass tensors they own.
        """
        # A : (B, D, L, N)
        # X : (B, D, L, N)

        # modifies X in place by doing a parallel scan.
        # more formally, X will be populated by these values :
        # H[t] = A[t] * H[t-1] + X[t] with H[0] = 0
        # which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps)

        # only supports L that is a power of two (mainly for a clearer code)
        
        B, D, L, _ = A.size()
        num_steps = int(math.log2(L))

        # up sweep (last 2 steps unfolded)
        # each iteration pairs neighbouring positions, folds the left one into the right one,
        # then keeps only the right ones (as views) for the next level
        Aa = A
        Xa = X
        for _ in range(num_steps-2):
            T = Xa.size(2)
            Aa = Aa.view(B, D, T//2, 2, -1)
            Xa = Xa.view(B, D, T//2, 2, -1)
            
            Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0]))
            Aa[:, :, :, 1].mul_(Aa[:, :, :, 0])

            Aa = Aa[:, :, :, 1]
            Xa = Xa[:, :, :, 1]

        # we have only 4, 2 or 1 nodes left
        if Xa.size(2) == 4:
            Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
            Aa[:, :, 1].mul_(Aa[:, :, 0])

            Xa[:, :, 3].add_(Aa[:, :, 3].mul(Xa[:, :, 2] + Aa[:, :, 2].mul(Xa[:, :, 1])))
        elif Xa.size(2) == 2:
            # L == 2 : a single combine finishes the scan, no down sweep needed
            Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
            return
        else:
            # L == 1 : nothing to do
            return

        # down sweep (first 2 steps unfolded)
        # strided slices select the nodes of each tree level directly from A and X
        Aa = A[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
        Xa = X[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
        Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 1]))
        Aa[:, :, 2].mul_(Aa[:, :, 1])

        for k in range(num_steps-3, -1, -1):
            Aa = A[:, :, 2**k-1:L:2**k]
            Xa = X[:, :, 2**k-1:L:2**k]

            T = Xa.size(2)
            Aa = Aa.view(B, D, T//2, 2, -1)
            Xa = Xa.view(B, D, T//2, 2, -1)

            # propagate the already-final right element of the previous pair into the left element of this pair
            Xa[:, :, 1:, 0].add_(Aa[:, :, 1:, 0].mul(Xa[:, :, :-1, 1]))
            Aa[:, :, 1:, 0].mul_(Aa[:, :, :-1, 1])

    @staticmethod
    def pscan_rev(A, X):
        """
        In-place reversed scan (mirror image of `pscan`). Returns None; the result
        is written into X and A is overwritten as scratch space.
        """
        # A : (B, D, L, N)
        # X : (B, D, L, N)

        # the same function as above, but in reverse
        # (if you flip the input, call pscan, then flip the output, you get what this function outputs)
        # it is used in the backward pass

        # only supports L that is a power of two (mainly for a clearer code)

        B, D, L, _ = A.size()
        num_steps = int(math.log2(L))

        # up sweep (last 2 steps unfolded)
        # mirror of pscan: the right element of each pair is folded into the left one
        Aa = A
        Xa = X
        for _ in range(num_steps-2):
            T = Xa.size(2)
            Aa = Aa.view(B, D, T//2, 2, -1)
            Xa = Xa.view(B, D, T//2, 2, -1)
                    
            Xa[:, :, :, 0].add_(Aa[:, :, :, 0].mul(Xa[:, :, :, 1]))
            Aa[:, :, :, 0].mul_(Aa[:, :, :, 1])

            Aa = Aa[:, :, :, 0]
            Xa = Xa[:, :, :, 0]

        # we have only 4, 2 or 1 nodes left
        if Xa.size(2) == 4:
            Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 3]))
            Aa[:, :, 2].mul_(Aa[:, :, 3])

            Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1].add(Aa[:, :, 1].mul(Xa[:, :, 2]))))
        elif Xa.size(2) == 2:
            # L == 2 : a single combine finishes the scan, no down sweep needed
            Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1]))
            return
        else:
            # L == 1 : nothing to do
            return

        # down sweep (first 2 steps unfolded)
        Aa = A[:, :, 0:L:2**(num_steps-2)]
        Xa = X[:, :, 0:L:2**(num_steps-2)]
        Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 2]))
        Aa[:, :, 1].mul_(Aa[:, :, 2])

        for k in range(num_steps-3, -1, -1):
            Aa = A[:, :, 0:L:2**k]
            Xa = X[:, :, 0:L:2**k]

            T = Xa.size(2)
            Aa = Aa.view(B, D, T//2, 2, -1)
            Xa = Xa.view(B, D, T//2, 2, -1)

            # propagate the left element of the next pair into the right element of this pair
            Xa[:, :, :-1, 1].add_(Aa[:, :, :-1, 1].mul(Xa[:, :, 1:, 0]))
            Aa[:, :, :-1, 1].mul_(Aa[:, :, 1:, 0])

    @staticmethod
    def forward(ctx, A_in, X_in):
        """
        Applies the parallel scan operation, as defined above. Returns a new tensor.
        If you can, privilege sequence lengths that are powers of two.

        Args:
            A_in : (B, L, D, N)
            X_in : (B, L, D, N)

        Returns:
            H : (B, L, D, N)
        """

        L = X_in.size(1)

        # cloning is required because of the in-place ops
        if L == npo2(L):
            A = A_in.clone()
            X = X_in.clone()
        else:
            # pad tensors (and clone btw)
            A = pad_npo2(A_in) # (B, npo2(L), D, N)
            X = pad_npo2(X_in) # (B, npo2(L), D, N)
        
        # prepare tensors
        A = A.transpose(2, 1) # (B, D, npo2(L), N)
        X = X.transpose(2, 1) # (B, D, npo2(L), N)

        # parallel scan (modifies X in-place)
        PScan.pscan(A, X)

        # save the ORIGINAL A_in (the local A was overwritten by the scan) and the scanned X
        ctx.save_for_backward(A_in, X)
        
        # slice [:, :L] (cut if there was padding)
        return X.transpose(2, 1)[:, :L]
    
    @staticmethod
    def backward(ctx, grad_output_in):
        """
        Flows the gradient from the output to the input. Returns two new tensors.

        Args:
            ctx : A_in : (B, L, D, N), X : (B, D, L, N)
            grad_output_in : (B, L, D, N)

        Returns:
            gradA : (B, L, D, N), gradX : (B, L, D, N)
        """

        A_in, X = ctx.saved_tensors

        L = grad_output_in.size(1)

        # cloning is required because of the in-place ops
        if L == npo2(L):
            grad_output = grad_output_in.clone()
            # the next padding will clone A_in
        else:
            grad_output = pad_npo2(grad_output_in) # (B, npo2(L), D, N)
            A_in = pad_npo2(A_in) # (B, npo2(L), D, N)

        # prepare tensors
        grad_output = grad_output.transpose(2, 1)
        A_in = A_in.transpose(2, 1) # (B, D, npo2(L), N)
        A = torch.nn.functional.pad(A_in[:, :, 1:], (0, 0, 0, 1)) # (B, D, npo2(L), N) shift 1 to the left (see hand derivation)

        # reverse parallel scan (modifies grad_output in-place)
        PScan.pscan_rev(A, grad_output)

        # gradient w.r.t. A at step t is H[t-1] * (scanned gradient at t); there is no contribution at t = 0
        Q = torch.zeros_like(X)
        Q[:, :, 1:].add_(X[:, :, :-1] * grad_output[:, :, 1:])

        return Q.transpose(2, 1)[:, :L], grad_output.transpose(2, 1)[:, :L]
    
# functional entry point: pscan(A, X) with A, X of shape (B, L, D, N)
pscan = PScan.apply

In [5]:
"""Simple, minimal implementation of Mamba in one file of PyTorch.

Suggest reading the following before/while reading the code:
    [1] Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Albert Gu and Tri Dao)
        https://arxiv.org/abs/2312.00752
    [2] The Annotated S4 (Sasha Rush and Sidd Karamcheti)
        https://srush.github.io/annotated-s4

Glossary:
    b: batch size                       (`B` in Mamba paper [1] Algorithm 2)
    l: sequence length                  (`L` in [1] Algorithm 2)
    d or d_model: hidden dim
    n or d_state: latent state dim      (`N` in [1] Algorithm 2)
    expand: expansion factor            (`E` in [1] Section 3.4)
    d_in or d_inner: d * expand         (`D` in [1] Algorithm 2)
    A, B, C, D: state space parameters  (See any state space representation formula)
                                        (B, C are input-dependent (aka selective, a key innovation in Mamba); A, D are not)
    Δ or delta: input-dependent step size
    dt_rank: rank of Δ                  (See [1] Section 3.6 "Parameterization of ∆")

"""
@dataclass
class ModelArgs:
    """Hyper-parameters of the Mamba model.

    After construction, `__post_init__` additionally:
        * sets `d_inner = int(expand * d_model)`,
        * resolves `dt_rank == 'auto'` to `ceil(d_model / 16)`,
        * rounds `vocab_size` up to a multiple of `pad_vocab_size_multiple`.
    """
    d_model: int
    n_layer: int
    vocab_size: int
    d_state: int = 16
    expand: int = 2
    dt_rank: Union[int, str] = 'auto'
    d_conv: int = 4
    pad_vocab_size_multiple: int = 8
    conv_bias: bool = True
    bias: bool = False

    def __post_init__(self):
        # width of the expanded inner representation
        self.d_inner = int(self.expand * self.d_model)

        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)

        # (-v % m) is the distance from v up to the next multiple of m (0 when already aligned)
        self.vocab_size += -self.vocab_size % self.pad_vocab_size_multiple

class Mamba(nn.Module):
    """Mamba language model: token embedding -> `n_layer` residual Mamba blocks -> RMSNorm -> LM head."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        self.embedding = nn.Embedding(args.vocab_size, args.d_model)
        self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)])
        self.norm_f = RMSNorm(args.d_model)

        self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
        # Weight tying is intentionally disabled (see "Weight Tying" paper):
        # self.lm_head.weight = self.embedding.weight

    def forward(self, input_ids=None):
        """
        Args:
            input_ids (long tensor, optional): shape (b, l)    (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            * when called with no argument: an `nn.Sequential` chaining embedding, every layer,
              the final norm and the LM head. It wraps the SAME sub-modules (no copy), which is what
              pipeline wrappers that expect a sequential model consume.
            * when `input_ids` is given: logits of shape (b, l, vocab_size).
        """
        if input_ids is None:
            return nn.Sequential(self.embedding, *self.layers, self.norm_f, self.lm_head)

        x = self.embedding(input_ids)
        for layer in self.layers:
            x = layer(x)
        return self.lm_head(self.norm_f(x))

    @staticmethod
    def _load_config_hf(model_name: str) -> dict:
        """Download (or read from cache) the hub config file of `model_name` and parse it as JSON."""
        from transformers.utils import CONFIG_NAME
        from transformers.utils.hub import cached_file

        resolved_archive_file = cached_file(model_name, CONFIG_NAME,
                                            _raise_exceptions_for_missing_entries=False)
        if resolved_archive_file is None:
            # cached_file returns None (instead of raising) because of the flag above
            raise FileNotFoundError(f"{CONFIG_NAME} not found for '{model_name}'")
        # context manager so the file handle is closed deterministically
        with open(resolved_archive_file, encoding="utf-8") as config_file:
            return json.load(config_file)

    @staticmethod
    def _args_from_config(config_data: dict):
        """Build ModelArgs from the three config keys this implementation reads."""
        return ModelArgs(
            d_model=config_data['d_model'],
            n_layer=config_data['n_layer'],
            vocab_size=config_data['vocab_size']
        )

    @staticmethod
    def from_config(pretrained_model_name: str):
        """Create a randomly initialised model whose sizes come from a hub config (no weights are loaded).

        Args:
            pretrained_model_name: hub repo id containing a config with
                `d_model`, `n_layer` and `vocab_size`.

        Returns:
            model: Mamba model with freshly initialised parameters

        Raises:
            FileNotFoundError: if the repo has no config file.
        """
        config_data = Mamba._load_config_hf(pretrained_model_name)
        return Mamba(Mamba._args_from_config(config_data))

    @staticmethod
    def from_pretrained(pretrained_model_name: str):
        """Load pretrained weights from HuggingFace into model.

        Args:
            pretrained_model_name: One of
                * 'state-spaces/mamba-2.8b-slimpj'
                * 'state-spaces/mamba-2.8b'
                * 'state-spaces/mamba-1.4b'
                * 'state-spaces/mamba-790m'
                * 'state-spaces/mamba-370m'
                * 'state-spaces/mamba-130m'

        Returns:
            model: Mamba model with weights loaded

        Raises:
            FileNotFoundError: if the repo has no config or weights file.
        """
        from transformers.utils import WEIGHTS_NAME
        from transformers.utils.hub import cached_file

        config_data = Mamba._load_config_hf(pretrained_model_name)
        model = Mamba(Mamba._args_from_config(config_data))

        resolved_archive_file = cached_file(pretrained_model_name, WEIGHTS_NAME,
                                            _raise_exceptions_for_missing_entries=False)
        if resolved_archive_file is None:
            raise FileNotFoundError(f"{WEIGHTS_NAME} not found for '{pretrained_model_name}'")
        state_dict = torch.load(resolved_archive_file, weights_only=True, map_location='cpu', mmap=True)

        # checkpoint keys carry a 'backbone.' prefix that this module layout does not use
        new_state_dict = {key.replace('backbone.', ''): value for key, value in state_dict.items()}
        model.load_state_dict(new_state_dict)

        return model


class ResidualBlock(nn.Module):
    """Simple block wrapping Mamba block with normalization and residual connection."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        # registration order (mixer, then norm) is kept: it determines parameter / state_dict ordering
        self.mixer = MambaBlock(args)
        self.norm = RMSNorm(args.d_model)

    def forward(self, x):
        """
        Normalises the input, runs it through the mixer and adds the input back.

        Args:
            x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            output: shape (b, l, d)
        """
        normed = self.norm(x)
        mixed = self.mixer(normed)
        return mixed + x
            

class MambaBlock(nn.Module):
    """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1]."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        d_model = args.d_model
        d_inner = args.d_inner

        # one projection produces both the SSM branch and the gating branch
        self.in_proj = nn.Linear(d_model, 2 * d_inner, bias=args.bias)

        # depthwise conv (groups == channels); the extra padding is trimmed to the input length in forward()
        self.conv1d = nn.Conv1d(
            in_channels=d_inner,
            out_channels=d_inner,
            kernel_size=args.d_conv,
            groups=d_inner,
            padding=args.d_conv - 1,
            bias=args.conv_bias,
        )

        # x_proj takes in `x` and outputs the input-specific Δ, B, C
        self.x_proj = nn.Linear(d_inner, args.dt_rank + 2 * args.d_state, bias=False)

        # dt_proj projects Δ from dt_rank to d_in
        self.dt_proj = nn.Linear(args.dt_rank, d_inner, bias=True)

        # A is stored as log(1..d_state), one identical row per inner channel
        state_index = torch.arange(1, args.d_state + 1)
        A_init = repeat(state_index, 'n -> d n', d=d_inner)
        self.A_log = nn.Parameter(torch.log(A_init))
        self.D = nn.Parameter(torch.ones(d_inner))
        self.out_proj = nn.Linear(d_inner, d_model, bias=args.bias)

    def forward(self, x):
        """Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1].

        Args:
            x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            output: shape (b, l, d)
        """
        _, seq_len, _ = x.shape
        d_inner = self.args.d_inner

        projected = self.in_proj(x)  # shape (b, l, 2 * d_in)
        x, res = torch.split(projected, [d_inner, d_inner], dim=-1)

        # conv1d works channel-first; cut its output back to the original length
        x = rearrange(x, 'b l d_in -> b d_in l')
        x = self.conv1d(x)[:, :, :seq_len]
        x = rearrange(x, 'b d_in l -> b l d_in')

        x = F.silu(x)

        # SSM output gated by the second half of the input projection
        gated = self.ssm(x) * F.silu(res)

        return self.out_proj(gated)

    def ssm(self, x):
        """Runs the SSM. See:
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        Args:
            x: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311

        """
        _, n = self.A_log.shape

        # A and D are input independent parameters; ∆, B, C are computed from x below
        # (this input dependence is what makes the state space "selective").
        A = -torch.exp(self.A_log.float())  # shape (d_in, n)
        D = self.D.float()

        delta_B_C = self.x_proj(x)  # (b, l, dt_rank + 2*n)

        # delta: (b, l, dt_rank). B, C: (b, l, n)
        delta, B, C = torch.split(delta_B_C, [self.args.dt_rank, n, n], dim=-1)
        delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)

        return self.selective_scan(x, delta, A, B, C, D)

    def selective_scan(self, x, delta, A, B, C, D):
        """Does selective scan algorithm. See:
            - Section 2 State Space Models in the Mamba paper [1]
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        This is the classic discrete state space formula:
            h(t + 1) = A h(t) + B x(t)
            y(t)     = C h(t) + D x(t)
        except B and C (and the step size delta, which is used for discretization) are dependent on the input x(t).
        The recurrence over t is evaluated with the parallel scan `pscan`.

        Args:
            x: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)
            delta: shape (b, l, d_in)
            A: shape (d_in, n)
            B: shape (b, l, n)
            C: shape (b, l, n)
            D: shape (d_in,)

        Returns:
            output: shape (b, l, d_in)
        """
        step = delta.unsqueeze(-1)  # (b, l, d_in, 1)

        # discretised state matrices
        deltaA = torch.exp(step * A)  # (b, l, d_in, n)
        deltaB = step * B.unsqueeze(2)  # (b, l, d_in, n)

        BX = deltaB * (x.unsqueeze(-1))  # (b, l, d_in, n)

        # hidden states for every position, computed in parallel
        hs = pscan(deltaA, BX)

        # (b, l, d_in, n) @ (b, l, n, 1) -> (b, l, d_in, 1) -> (b, l, d_in)
        y = torch.matmul(hs, C.unsqueeze(-1)).squeeze(3)

        # skip connection scaled by D
        return y + D * x


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned per-feature scale."""

    def __init__(self,
                 d_model: int,
                 eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        """
        Args:
            x: tensor whose last dimension has size d_model

        Returns:
            x / sqrt(mean(x^2) + eps) * weight, same shape as x
        """
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight

In [6]:
# "flytech/python-codes-25k" --- 2000 epochs
# "iamtarun/python_code_instructions_18k_alpaca"----1500 epochs
# muellerzr/python-stack-v1-functions-filtered-llama-3-8B----8000 epochs
# bigcode/python-stack-v1-functions-filtered-sc2-subset------2000 epochs
# jean1/45k_python_code_chinese_instruction ---- 2500 epochs
# MohamedSaeed-dev/PythonDataV2 ---- 3000 epochs
# Vezora/Tested-143k-Python-Alpaca --- 4000 epochs -- still yet to train for more epochs

# HydraLM/instruct-python-500k-standardized --- 2000 epochs-----trained for the first 100k
# "AayushMathur/manim_python_alpaca"
# "gauravvaid/python-code_samples"
# Fraser/python-state-changes
# "mengmengmmm/csn_python_trainuse"

In [6]:
# The rest of the notebook reads a `message` column, takes the first 100k rows and saves a checkpoint named
# "...standardized100k...": that is the HydraLM standardized dataset (1,002,698 rows in the recorded output),
# not the Vezora Alpaca-style dataset this cell previously named.
ds = load_dataset("HydraLM/instruct-python-500k-standardized")
ds

Downloading readme:   0%|          | 0.00/511 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/164M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/180M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/185M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1002698 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['message', 'message_type', 'message_id', 'conversation_id'],
        num_rows: 1002698
    })
})

In [7]:
# Materialise the train split as a pandas DataFrame.
# to_pandas() leaves `ds` itself untouched (set_format would switch the whole DatasetDict
# to pandas output in place, making earlier cells behave differently on re-run).
df = ds["train"].to_pandas()
df.head()

Unnamed: 0,message,message_type,message_id,conversation_id
0,"What does the ""yield"" keyword do?: What is the...",instruction,0,0
1,"To understand what yield does, you must unders...",output,1,0
2,What is a metaclass in Python?: What are metac...,instruction,0,1
3,Classes as objects\nBefore understanding metac...,output,1,1
4,How to make a chain of function decorators in ...,instruction,0,2


In [8]:
# df.shape

In [9]:
# df["text"] = df["start"] + ". " + df["code"] + ". " + df["end"]
# df.head()

In [10]:
# Build one training corpus string from the first 100000 messages, separated by ". [EOS] ".
text = ". [EOS] ".join(df["message"].iloc[:100000])

In [11]:
# # converting the pandas dataset to hugging face dataset format
# small_ds = Dataset.from_pandas(df)
# print(small_ds)

# dataloader = DataLoader(small_ds["text"], batch_size=1, shuffle=True, pin_memory=True)
# print(len(dataloader))

In [12]:
# Checkpoint to resume from (inside the cloned `m` repository).
file_path = "/kaggle/working/m/mamba_python-2000_epoch_stage5-3.pt"
# weights_only=True restricts unpickling to tensors and plain containers, which is all a
# {"model_state_dict", "optimizer_state_dict"} checkpoint needs, and avoids executing arbitrary
# pickled code from a downloaded file (same option Mamba.from_pretrained already uses).
state = torch.load(file_path, map_location="cpu", weights_only=True)

In [13]:
# loading the tokenizer
tokenizer = AutoTokenizer.from_pretrained('pt-sk/mamba')

# building the model from the hub config only: parameters are freshly initialised here,
# the trained weights are restored from the checkpoint a few lines below
mamba_model = Mamba.from_config("pt-sk/mamba_python")

# pipeline to distributed model training
# Called with no arguments, mamba_model() returns an nn.Sequential (embedding, blocks, final norm, LM head)
# that Pipe splits into two partitions of 15 and 12 modules.
# NOTE(review): balance=[15, 12] presumes the sequential has 27 modules, i.e. n_layer == 24 in the config — confirm.
# NOTE(review): chunks=4 is presumably the number of micro-batches per batch — confirm against fairscale docs.
model = fairscale.nn.Pipe(mamba_model(), balance=[15, 12], chunks = 4)

# loading the model weights from the checkpoint (keys must match the Pipe-wrapped sequential layout)
model.load_state_dict(state["model_state_dict"])

# model = fairscale.nn.Pipe(mamba_model(), balance=[15, 12], chunks = 4, devices = ["xla:0", "xla:1"])

# creating the optimizer (its saved state is restored in the next cell)
optimizer = AdamW(model.parameters(), lr=0.00001)

# loading optimizer weights
# optimizer.load_state_dict(state["optimizer_state_dict"])

tokenizer_config.json:   0%|          | 0.00/5.33k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/799k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/457k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/563 [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


config.json:   0%|          | 0.00/199 [00:00<?, ?B/s]

In [14]:
# Restore the optimizer state stored in the checkpoint next to the model weights.
optimizer.load_state_dict(state["optimizer_state_dict"])

In [15]:
# Tokenise the whole corpus once into a flat 1-D tensor of token ids.
tokens = tokenizer(text, return_tensors="pt").input_ids.squeeze(0)

block_size = 512
batch_size = 8

len_tokens = len(tokens)
def get_batch():
    """Sample a random batch of contiguous token windows from the corpus.

    Returns:
        x: long tensor of shape (batch_size, block_size). Only the inputs are returned;
           callers derive next-token targets by shifting x themselves.

    Raises:
        ValueError: if the corpus holds fewer than block_size tokens.
    """
    if len_tokens < block_size:
        raise ValueError(
            f"corpus has {len_tokens} tokens, need at least block_size={block_size}"
        )
    # randint's upper bound is exclusive, so +1 makes the last full window
    # (start == len_tokens - block_size) reachable and supports len_tokens == block_size
    ix = torch.randint(len_tokens - block_size + 1, (batch_size,))
    x = torch.stack([tokens[i:i+block_size] for i in ix])
    return x

In [16]:
# state = torch.load("/kaggle/working/mamba_python/mamba_python_1.pt", map_location="cpu")
# mamba_model.load_state_dict(state["model_state_dict"])
# optimizer.load_state_dict(state["optimizer_state_dict"])

In [17]:
# Trainer
# NOTE: an "epoch" here is ONE optimisation step on one random batch, not a pass over the corpus.
epochs = 2000
iterator = tqdm(range(epochs), desc="Training", postfix={"train_loss": 0.0})

# The loss module is stateless: build it once instead of once per step.
loss_fct = torch.nn.CrossEntropyLoss()

for epoch in iterator:

  encoded_inp = get_batch()
  # the first pipeline stage lives on cuda:0
  logits = model(encoded_inp.to("cuda:0"))

  # the pipeline output sits on the last stage's device; put the labels wherever the logits are
  labels = encoded_inp.to(logits.device)

  # next-token objective: the prediction at position t is scored against the token at t+1
  shift_logits = logits[:, :-1, :].contiguous()
  labels = labels[:, 1:].contiguous()
  loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))

  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()

  # keep only the scalar for reporting, then drop the GPU tensors so their memory is free
  # before the next forward pass (no need to copy the full logits to the CPU for that)
  train_loss = loss.item()
  del loss, logits, shift_logits, labels, encoded_inp

  iterator.set_postfix({"train_loss": train_loss}, refresh=False)

Training: 100%|██████████| 2000/2000 [2:28:29<00:00,  4.45s/it, train_loss=3.85]  


In [18]:
# Save model and optimizer state together so a later session can resume training.
# The file name is relative, so it is written to the current working directory.
torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict()
}, "mamba_python-standardized100k2000epoch-stage6.pt")

In [19]:
# Authenticate with the Hugging Face Hub; called without arguments, so no token is hard-coded in the notebook.
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [21]:
# Upload the checkpoint written by the save cell to the `pt-sk/m` model repo.
# NOTE(review): the absolute path assumes the notebook's working directory is /kaggle/working — confirm.
from huggingface_hub import HfApi
api = HfApi()
api.upload_file(
    path_or_fileobj="/kaggle/working/mamba_python-standardized100k2000epoch-stage6.pt",
    path_in_repo="mamba_python-standardized100k2000epoch-stage6.pt",
    repo_id="pt-sk/m",
    repo_type="model",
)

mamba_python-standardized100k2000epoch-stage6.pt:   0%|          | 0.00/2.01G [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/pt-sk/m/commit/53de98b25ccfc4a163499cf71f50bd8b6dec9133', commit_message='Upload mamba_python-standardized100k2000epoch-stage6.pt with huggingface_hub', commit_description='', oid='53de98b25ccfc4a163499cf71f50bd8b6dec9133', pr_url=None, pr_revision=None, pr_num=None)

In [20]:
# Release GPU memory that is no longer referenced.
import gc

def clear_gpu_memory():
    """Run the Python garbage collector, then ask PyTorch to release its cached CUDA blocks."""
    gc.collect()
    torch.cuda.empty_cache()

if not torch.cuda.is_available():
    print("CUDA is not available.")
else:
    clear_gpu_memory()
    print("GPU memory cleared.")

GPU memory cleared.


In [None]:
# import torch
# import gc

# def get_gpu_memory_usage():
#     allocated_memory = torch.cuda.memory_allocated()
#     reserved_memory = torch.cuda.memory_reserved()
#     return allocated_memory, reserved_memory

# def list_gpu_variables():
#     for obj in gc.get_objects():
#         try:
#             if torch.is_tensor(obj) and obj.is_cuda:
#                 print(f"Tensor on GPU: {obj}, Size: {obj.size()}, Memory: {obj.element_size() * obj.nelement()}")
#         except Exception as e:
#             pass

# if torch.cuda.is_available():
#     allocated, reserved = get_gpu_memory_usage()
#     print(f"Allocated GPU memory: {allocated} bytes")
#     print(f"Reserved GPU memory: {reserved} bytes")
#     list_gpu_variables()
# else:
#     print("CUDA is not available.")

In [33]:
import torch
import torch.nn.functional as F


def generate(model,
             tokenizer,
             prompt: str,
             n_tokens_to_gen: int = 100,
             sample: bool = True,
             top_k: int = 40,
             device: str = "cuda:0"):
    """Autoregressively extend `prompt` by `n_tokens_to_gen` tokens.

    Args:
        model: callable mapping token ids (b, l) to logits (b, l, vocab); it is put in eval mode
            and left there.
        tokenizer: callable returning an object with `.input_ids` for a prompt, and providing
            `.decode(list_of_ids)`.
        prompt: text to continue.
        n_tokens_to_gen: number of tokens to append.
        sample: draw from the (optionally top-k filtered) distribution when True,
            otherwise take the most likely token.
        top_k: keep only the k most likely tokens before sampling; None disables the filter.
            Values larger than the vocabulary are clamped to the vocabulary size.
        device: device the prompt ids and the sampling tensors are placed on.

    Returns:
        The decoded text (prompt + generated tokens) of the first sequence in the batch.
    """
    model.eval()

    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
    print(input_ids)

    # no gradients are needed anywhere in the generation loop
    with torch.no_grad():
        for _ in range(n_tokens_to_gen):
            # the model output may live on another device (e.g. last pipeline stage): bring it back
            next_token_logits = model(input_ids)[:, -1].to(device)
            probs = F.softmax(next_token_logits, dim=-1)

            if top_k is not None:
                # topk raises if k exceeds the vocabulary size, so clamp it
                k = min(top_k, probs.size(-1))
                values, _ = torch.topk(probs, k=k)
                # zero everything below the k-th largest probability, then renormalise
                probs[probs < values[:, -1, None]] = 0
                probs = probs / probs.sum(dim=1, keepdim=True)

            if sample:
                next_indices = torch.multinomial(probs, num_samples=1)
            else:
                next_indices = torch.argmax(probs, dim=-1)[:, None]

            input_ids = torch.cat([input_ids, next_indices], dim=1)

    # only the first sequence is returned, so only that one is decoded
    return tokenizer.decode(input_ids[0].tolist())

In [41]:
generate(model, tokenizer, "What is a metaclass in Python?")

tensor([[ 1276,   310,   247,  1313,   317, 14407,   275, 13814,    32]],
       device='cuda:0')


'What is a metaclass in Python? I am trying to find an image that doesn\'t think the way of using the list and the __all__ method and the first is the "cac" list containing the __iter__ method. Any you have a dictionary:\n>>> d = {k: k for k, v in d.items() if v == k.values_count}\n[1, 2, 3],\n \'b\': [1,2,3,4,5,6,7] else'