<a href="https://colab.research.google.com/github/payitforwardforever/Deep-Dive-Into-AI-With-MLX-PyTorch/blob/master/Deepseek_R1_from_first_principles_and_concepts.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import locale

def getpreferredencoding(do_setlocale=True):
    """Fixed-answer replacement for ``locale.getpreferredencoding``.

    Args:
        do_setlocale: Accepted so the signature matches the function being
            replaced; the value is never read.

    Returns:
        The string ``"UTF-8"``, unconditionally.
    """
    encoding_name = "UTF-8"
    return encoding_name


# Swap the stdlib lookup for the constant version defined above, so every later
# call to locale.getpreferredencoding() in this kernel reports UTF-8.
setattr(locale, "getpreferredencoding", getpreferredencoding)

!pip install geoopt

Collecting geoopt
  Downloading geoopt-0.5.0-py3-none-any.whl.metadata (6.7 kB)
Downloading geoopt-0.5.0-py3-none-any.whl (90 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m90.1/90.1 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: geoopt
Successfully installed geoopt-0.5.0


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from geoopt import PoincareBall
from geoopt.optim import RiemannianAdam

#model specs
class HyperbolicRLModel(nn.Module):
    """Two-layer MLP whose hidden activations make a round trip through a Poincare ball.

    Pipeline: ``fc1`` -> ReLU -> ``manifold.expmap0`` -> ``manifold.logmap0`` -> ``fc2``.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of the hidden layer.
        output_dim: Size of the last dimension of the output.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Attribute names and creation order are kept stable (fc1, fc2, manifold).
        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=output_dim)
        self.manifold = PoincareBall(c=1.0)

    def forward(self, x):
        """Return the output-layer activations for ``x``."""
        hidden = F.relu(self.fc1(x))
        # Map the hidden activations to hyperbolic space ...
        on_ball = self.manifold.expmap0(hidden)
        # ... and straight back to Euclidean space before the output layer.
        # NOTE(review): no operation runs between the two maps, so this looks
        # like an identity up to numerical effects -- confirm this is intended.
        tangent = self.manifold.logmap0(on_ball)
        return self.fc2(tangent)
class CompactRLModel(nn.Module):
    """Plain two-layer fully connected network: Linear -> ReLU -> Linear.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of the hidden layer.
        output_dim: Size of the last dimension of the output.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Attribute names and creation order are kept stable (fc1, then fc2).
        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)



In [4]:


# reward function
def compute_accuracy_reward(output, target):
    """Score how closely ``output`` points in the same direction as ``target``.

    Cosine similarity is taken along the last dimension and averaged over all
    remaining positions.

    Returns:
        The mean similarity as a plain Python float (extracted with ``.item()``,
        so the value carries no autograd history).
    """
    return F.cosine_similarity(output, target, dim=-1).mean().item()

def compute_format_reward(response, required_format):
    """Return +1.0 when ``required_format`` occurs in ``response``, otherwise -1.0.

    The check is a plain ``in`` containment test.
    """
    if required_format in response:
        return 1.0
    return -1.0

def compute_combined_reward(output, target, response, required_format):
    """Sum the accuracy reward and the format reward into one scalar.

    Args:
        output: Forwarded to ``compute_accuracy_reward`` as the model output.
        target: Forwarded to ``compute_accuracy_reward`` as the reference.
        response: Forwarded to ``compute_format_reward`` as the text to check.
        required_format: Forwarded to ``compute_format_reward`` as the marker.

    Returns:
        ``compute_accuracy_reward(output, target)`` plus
        ``compute_format_reward(response, required_format)``.
    """
    total = compute_accuracy_reward(output, target)
    total += compute_format_reward(response, required_format)
    return total

#train function
def train_with_rl(model, optimizer, dataset, input_dim, output_dim, epochs=5, max_seq_len=50):
    """Run a reward-maximisation training loop and return a per-step log.

    For every sample in every epoch the model is fed a random input embedding,
    its output is scored against a random target embedding (cosine similarity)
    plus a format bonus on a dummy response, and one optimizer step is taken on
    the negative reward. The embeddings and the response are placeholders to be
    replaced by a real encoder/decoder.

    Args:
        model: Module mapping ``(1, max_seq_len, input_dim)`` tensors to
            ``(1, max_seq_len, output_dim)`` tensors.
        optimizer: Optimizer built over ``model``'s parameters.
        dataset: Iterable of dicts; only the ``"Prompt"`` key is read here.
        input_dim: Last dimension of the generated input embedding.
        output_dim: Last dimension of the generated target embedding.
        epochs: Number of passes over ``dataset``.
        max_seq_len: Sequence length of the generated embeddings.

    Returns:
        A list with one dict per training step, holding the keys ``"epoch"``,
        ``"prompt"``, ``"reward"``, ``"loss"`` and ``"response"``.
    """
    # Take the device from the model itself rather than from a notebook-global
    # `device`, so the function works no matter which cells ran before it.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    required_format = "<think>"
    history = []

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        for sample in dataset:
            prompt = sample["Prompt"]

            # Generate random embeddings for input and target (replace with actual embeddings in production)
            input_embedding = torch.rand((1, max_seq_len, input_dim)).to(device)
            target_embedding = torch.rand((1, max_seq_len, output_dim)).to(device)

            optimizer.zero_grad()

            # Forward pass
            output_embedding = model(input_embedding)

            # replace with actual decoder this is a dummy
            model_response = "<think> response </think>"

            # The accuracy term is kept as a tensor so it stays attached to the
            # autograd graph. Wrapping an already-extracted Python float in a new
            # tensor would create a leaf with no path back to the model, and the
            # optimizer step would then change nothing.
            accuracy_term = F.cosine_similarity(output_embedding, target_embedding, dim=-1).mean()
            format_reward = compute_format_reward(model_response, required_format)
            combined_reward = accuracy_term.item() + format_reward

            # Loss = negative reward (to maximize reward). The format reward is a
            # constant w.r.t. the parameters; it is included only so that the
            # reported loss equals -reward.
            loss = -(accuracy_term + format_reward)

            # A fully frozen model yields a loss without a graph; skip the update
            # in that case instead of raising.
            if loss.requires_grad:
                loss.backward()
                optimizer.step()

            history.append({
                "epoch": epoch + 1,
                "prompt": prompt,
                "reward": combined_reward,
                "loss": loss.item(),
                "response": model_response
            })

            print(f"Prompt: {prompt[:30]}... | Response: {model_response} | Combined Reward: {combined_reward:.4f} | Loss: {loss.item():.4f}")

    return history



In [6]:
#Distill function
def distill_model(teacher_model, student_model, dataset, optimizer, input_dim, epochs=5, max_seq_len=50):
    """Distill knowledge from the teacher model to the smaller student model.

    For every sample in every epoch both models receive the same random input
    embedding (a placeholder for real embeddings) and the student takes one
    optimizer step on the mean squared error between its output and the
    teacher's output.

    Args:
        teacher_model: Module providing the regression targets; it is run under
            ``torch.no_grad()`` and is not updated.
        student_model: Module being trained.
        dataset: Iterable of dicts; only the ``"Prompt"`` key is read here.
        optimizer: Optimizer built over ``student_model``'s parameters.
        input_dim: Last dimension of the generated input embedding.
        epochs: Number of passes over ``dataset``.
        max_seq_len: Sequence length of the generated embeddings.

    Returns:
        None. Progress is reported through ``print``.
    """
    # Take the device from the student rather than from a notebook-global
    # `device`, so the function works no matter which cells ran before it.
    first_param = next(student_model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # Targets should come from the teacher in inference mode (matters for layers
    # such as dropout or batch norm). The previous mode is restored on exit so
    # the caller's teacher is left as it was found.
    teacher_was_training = teacher_model.training
    teacher_model.eval()
    try:
        for epoch in range(epochs):
            print(f"Distillation Epoch {epoch + 1}/{epochs}")
            for sample in dataset:
                prompt = sample["Prompt"]

                # Generate random embeddings for input (replace with actual embeddings in production)
                input_embedding = torch.rand((1, max_seq_len, input_dim)).to(device)

                # Teacher model output
                with torch.no_grad():
                    teacher_output = teacher_model(input_embedding)

                # Student model output
                student_output = student_model(input_embedding)

                # Loss = Mean Squared Error between teacher and student outputs
                loss = F.mse_loss(student_output, teacher_output)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                print(f"Prompt: {prompt[:30]}... | Distillation Loss: {loss.item():.4f}")
    finally:
        teacher_model.train(teacher_was_training)




In [9]:
#  main
if __name__ == "__main__":
    # ---------------------------
    # 1) Device Setup
    # ---------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Fix the RNG state so weight initialisation and the random placeholder
    # embeddings are the same on every Restart & Run All.
    seed = 42  # arbitrary fixed value
    torch.manual_seed(seed)

    # ---------------------------
    # 2) Model and Optimizer Setup
    # ---------------------------
    input_dim = 128  # Example input dimension
    hidden_dim = 256  # Hidden layer dimension
    output_dim = 128  # Output dimension matches input for reconstruction

    model = HyperbolicRLModel(input_dim, hidden_dim, output_dim).to(device)
    optimizer = RiemannianAdam(model.parameters(), lr=1e-4)

    # ---------------------------
    # Dummy data
    # ---------------------------
    dataset = [
        {"Prompt": "If you have 3 apples and you take away 2, how many do you have?", "Target": "2 apples."},
        {"Prompt": "If a train travels 60 miles per hour for 3 hours, how far does it go?", "Target": "180 miles."},
        {"Prompt": "If a store sells a dozen eggs for $3, how much do 2 dozen eggs cost?", "Target": "$6."},
        {"Prompt": "What is the next number in the sequence: 2, 4, 8, 16?", "Target": "32."},
        {"Prompt": "If a rectangle has a length of 5 units and a width of 3 units, what is its area?", "Target": "15 square units."},
        {"Prompt": "If you flip a fair coin 3 times, what is the probability of getting exactly 2 heads?", "Target": "3/8."},
        {"Prompt": "A car decreases its speed from 100 km/h to 50 km/h in 5 seconds. What is the acceleration?", "Target": "-10 km/h/s."},
        {"Prompt": "If 5 workers can complete a job in 10 days, how long will it take 10 workers to complete the same job?", "Target": "5 days."},
        {"Prompt": "If the sum of two numbers is 15 and their product is 56, what are the numbers?", "Target": "7 and 8."},
        {"Prompt": "What is the angle between the hour and minute hands of a clock at 3:15?", "Target": "7.5 degrees."}
    ]

    # ---------------------------
    # 3) Training
    # ---------------------------
    # `history` holds one record per training step, as returned by train_with_rl.
    history = train_with_rl(
        model=model,
        optimizer=optimizer,
        dataset=dataset,
        input_dim=input_dim,
        output_dim=output_dim,
        epochs=10,
        max_seq_len=50
    )

    print("Training complete.")

    # ---------------------------
    # 4) Distillation
    # ---------------------------
    # The student uses half the teacher's hidden width and a plain Adam optimizer.
    smaller_model = CompactRLModel(input_dim, hidden_dim // 2, output_dim).to(device)
    distill_optimizer = torch.optim.Adam(smaller_model.parameters(), lr=1e-4)

    distill_model(
        teacher_model=model,
        student_model=smaller_model,
        dataset=dataset,
        optimizer=distill_optimizer,
        input_dim=input_dim,
        epochs=10,
        max_seq_len=50
    )

    print("Distillation complete.")


Epoch 1/10
Prompt: If you have 3 apples and you t... | Response: <think> response </think> | Combined Reward: 0.8832 | Loss: -0.8832
Prompt: If a train travels 60 miles pe... | Response: <think> response </think> | Combined Reward: 0.8864 | Loss: -0.8864
Prompt: If a store sells a dozen eggs ... | Response: <think> response </think> | Combined Reward: 0.8739 | Loss: -0.8739
Prompt: What is the next number in the... | Response: <think> response </think> | Combined Reward: 0.8987 | Loss: -0.8987
Prompt: If a rectangle has a length of... | Response: <think> response </think> | Combined Reward: 0.8749 | Loss: -0.8749
Prompt: If you flip a fair coin 3 time... | Response: <think> response </think> | Combined Reward: 0.8755 | Loss: -0.8755
Prompt: A car decreases its speed from... | Response: <think> response </think> | Combined Reward: 0.8860 | Loss: -0.8860
Prompt: If 5 workers can complete a jo... | Response: <think> response </think> | Combined Reward: 0.8904 | Loss: -0.8904
Prompt: If th